-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmodel.py
More file actions
91 lines (78 loc) · 3.24 KB
/
model.py
File metadata and controls
91 lines (78 loc) · 3.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv, GraphConv, GATConv
import dgl.function as fn
class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE node encoder with mean aggregation.

    Maps input node features of size ``in_feats`` to embeddings of size
    ``h_feats`` via two SAGEConv layers, with dropout then ReLU applied
    between them.
    """
    def __init__(self, in_feats, h_feats, dropout=0.5):
        super().__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, 'mean')
        self.dropout = nn.Dropout(dropout)
        self.conv2 = SAGEConv(h_feats, h_feats, 'mean')
    def forward(self, g, in_feat):
        """Return per-node embeddings for graph ``g`` given features ``in_feat``."""
        x = self.conv1(g, in_feat)
        # NOTE: dropout is deliberately applied before the ReLU nonlinearity.
        x = F.relu(self.dropout(x))
        return self.conv2(g, x)
class GCN(nn.Module):
    """Two-layer graph convolutional node encoder.

    Projects node features from ``in_feats`` to ``h_feats`` dimensions
    using two GraphConv layers (zero-in-degree nodes allowed), with
    dropout then ReLU applied between them.
    """
    def __init__(self, in_feats, h_feats, dropout=0.5):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats, allow_zero_in_degree=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = GraphConv(h_feats, h_feats, allow_zero_in_degree=True)
    def forward(self, g, in_feat):
        """Return per-node embeddings for graph ``g`` given features ``in_feat``."""
        x = self.conv1(g, in_feat)
        # NOTE: dropout is deliberately applied before the ReLU nonlinearity.
        x = F.relu(self.dropout(x))
        return self.conv2(g, x)
class GAT(nn.Module):
    """Two-layer graph attention node encoder with a single attention head.

    Each GATConv uses one head, so its (N, 1, h_feats) output is
    collapsed back to (N, h_feats) after every layer. Dropout then ReLU
    are applied between the layers.
    """
    def __init__(self, in_feats, h_feats, dropout=0.5):
        super().__init__()
        self.conv1 = GATConv(in_feats, h_feats, 1, allow_zero_in_degree=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = GATConv(h_feats, h_feats, 1, allow_zero_in_degree=True)
    def forward(self, g, in_feat):
        """Return per-node embeddings for graph ``g`` given features ``in_feat``."""
        x = self.conv1(g, in_feat)
        # drop the singleton head dimension: (N, 1, F) -> (N, F)
        x = x.reshape(x.size(0), x.size(2))
        # NOTE: dropout is deliberately applied before the ReLU nonlinearity.
        x = F.relu(self.dropout(x))
        x = self.conv2(g, x)
        # drop the singleton head dimension again after the second layer
        return x.reshape(x.size(0), x.size(2))
class MLPPredictor(nn.Module):
    """Edge scorer: concatenated endpoint embeddings -> 2-layer MLP -> sigmoid.

    Produces one probability-like score in (0, 1) per edge.
    """
    def __init__(self, h_feats):
        super().__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, 1)
        self.sig = nn.Sigmoid()
    def apply_edges(self, edges):
        """Compute a scalar score for each edge in the batch.

        ``edges.src['h']`` / ``edges.dst['h']`` are the endpoint embeddings.
        """
        pair = torch.cat((edges.src['h'], edges.dst['h']), dim=1)
        hidden = F.relu(self.W1(pair))
        # (E, 1) -> (E,) so the score tensor is flat per edge
        return {'score': self.sig(self.W2(hidden)).squeeze(1)}
    def forward(self, g, h):
        """Score every edge of ``g`` given node embeddings ``h``."""
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            return g.edata['score']
class HeteroDotProductPredictor(nn.Module):
    """Scores edges of one edge type in a heterograph via the dot product
    of the incident source and destination node embeddings."""
    def forward(self, graph, h, etype):
        # h: presumably a dict mapping node type -> embedding tensor,
        # as produced by an upstream heterograph encoder — TODO confirm
        # against the caller.
        with graph.local_scope():
            graph.ndata['h'] = h # assigns 'h' of all node types in one shot
            # u_dot_v writes src·dst into edge feature 'score', restricted
            # to edges of the requested type.
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']
class HeteroMLPPredictor(nn.Module):
    """Edge classifier: concatenated endpoint embeddings -> MLP with
    dropout -> sigmoid, yielding one score per candidate edge type.
    """
    def __init__(self, h_feats, edge_types, dropout=0.5):
        super().__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.dropout = nn.Dropout(dropout)
        self.W2 = nn.Linear(h_feats, edge_types)
        self.sig = nn.Sigmoid()
    def apply_edges(self, edges):
        """Compute a per-edge score vector of length ``edge_types``.

        ``edges.src['h']`` / ``edges.dst['h']`` are the endpoint embeddings.
        """
        pair = torch.cat((edges.src['h'], edges.dst['h']), dim=1)
        # NOTE: dropout sits between the two linear layers, before ReLU.
        hidden = F.relu(self.dropout(self.W1(pair)))
        return {'score': self.sig(self.W2(hidden))} # dim: edge_types
    def forward(self, g, h):
        """Score every edge of ``g`` given node embeddings ``h``."""
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            return g.edata['score']