model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, global_max_pool, GlobalAttention

# Debugging aid: reports the operation that produced NaN/Inf gradients; disable for faster training runs.
torch.autograd.set_detect_anomaly(True)

class WiKG(nn.Module):
    def __init__(self, dim_in=384, dim_hidden=512, topk=6, n_classes=2, agg_type='bi-interaction', dropout=0.3, pool='attn'):
        super().__init__()
        self._fc1 = nn.Sequential(nn.Linear(dim_in, dim_hidden), nn.LeakyReLU())

        self.W_head = nn.Linear(dim_hidden, dim_hidden)
        self.W_tail = nn.Linear(dim_hidden, dim_hidden)
        self.scale = dim_hidden ** -0.5
        self.topk = topk
        self.agg_type = agg_type

        self.gate_U = nn.Linear(dim_hidden, dim_hidden // 2)
        self.gate_V = nn.Linear(dim_hidden, dim_hidden // 2)
        self.gate_W = nn.Linear(dim_hidden // 2, dim_hidden)

        if self.agg_type == 'gcn':
            self.linear = nn.Linear(dim_hidden, dim_hidden)
        elif self.agg_type == 'sage':
            self.linear = nn.Linear(dim_hidden * 2, dim_hidden)
        elif self.agg_type == 'bi-interaction':
            self.linear1 = nn.Linear(dim_hidden, dim_hidden)
            self.linear2 = nn.Linear(dim_hidden, dim_hidden)
        else:
            raise NotImplementedError(f"Unknown agg_type: {self.agg_type}")

        self.activation = nn.LeakyReLU()
        self.message_dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim_hidden)
        self.fc = nn.Linear(dim_hidden, n_classes)

        if pool == "mean":
            self.readout = global_mean_pool
        elif pool == "max":
            self.readout = global_max_pool
        elif pool == "attn":
            att_net = nn.Sequential(nn.Linear(dim_hidden, dim_hidden // 2), nn.LeakyReLU(), nn.Linear(dim_hidden // 2, 1))
            self.readout = GlobalAttention(att_net)
        else:
            raise NotImplementedError(f"Unknown pool: {pool}")
    def forward(self, x):
        # Inputs may arrive either as a raw feature tensor or as a dict holding the features under "feature".
        if isinstance(x, dict):
            x = x["feature"]

        x = self._fc1(x)  # [B, N, dim_hidden]
        x = (x + x.mean(dim=1, keepdim=True)) * 0.5  # blend each node with the bag-level mean

        e_h = self.W_head(x)  # head embeddings
        e_t = self.W_tail(x)  # tail embeddings

        # construct the neighbourhood: scaled dot-product scores between all node pairs
        attn_logit = (e_h * self.scale) @ e_t.transpose(-2, -1)  # [B, N, N]
        topk_weight, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)  # [B, N, topk]
        topk_index = topk_index.to(torch.long)

        # gather the top-k tail embeddings for every head node via advanced indexing
        topk_index_expanded = topk_index.expand(e_t.size(0), -1, -1)  # [B, N, topk]
        batch_indices = torch.arange(e_t.size(0)).view(-1, 1, 1).to(topk_index.device)  # [B, 1, 1]
        Nb_h = e_t[batch_indices, topk_index_expanded, :]  # [B, N, topk, dim_hidden]

        # turn the top-k scores into probabilities
        topk_prob = F.softmax(topk_weight, dim=2)
        eh_r = torch.mul(topk_prob.unsqueeze(-1), Nb_h) + torch.matmul((1 - topk_prob).unsqueeze(-1), e_h.unsqueeze(2))  # element-wise term + matmul term

        # gated knowledge attention over the neighbours
        e_h_expand = e_h.unsqueeze(2).expand(-1, -1, self.topk, -1)
        gate = torch.tanh(e_h_expand + eh_r)
        ka_weight = torch.einsum('ijkl,ijkm->ijk', Nb_h, gate)
        ka_prob = F.softmax(ka_weight, dim=2).unsqueeze(dim=2)
        e_Nh = torch.matmul(ka_prob, Nb_h).squeeze(dim=2)  # [B, N, dim_hidden]

        if self.agg_type == 'gcn':
            embedding = e_h + e_Nh
            embedding = self.activation(self.linear(embedding))
        elif self.agg_type == 'sage':
            embedding = torch.cat([e_h, e_Nh], dim=2)
            embedding = self.activation(self.linear(embedding))
        elif self.agg_type == 'bi-interaction':
            sum_embedding = self.activation(self.linear1(e_h + e_Nh))
            bi_embedding = self.activation(self.linear2(e_h * e_Nh))
            embedding = sum_embedding + bi_embedding

        h = self.message_dropout(embedding)
        h = self.readout(h.squeeze(0), batch=None)  # pool node embeddings into a single bag embedding
        h = self.norm(h)
        h = self.fc(h)
        return h
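
# A minimal, self-contained sketch (illustrative only, not used by WiKG): it replays the
# top-k neighbour construction from forward() on a tiny tensor so the advanced-indexing
# step is easy to inspect. Shapes and names here are assumptions chosen for the demo.
def _demo_topk_gather(B=1, N=5, C=8, k=3):
    e_h = torch.randn(B, N, C)                                      # head embeddings
    e_t = torch.randn(B, N, C)                                      # tail embeddings
    attn_logit = (e_h * C ** -0.5) @ e_t.transpose(-2, -1)          # [B, N, N] pairwise scores
    topk_weight, topk_index = torch.topk(attn_logit, k=k, dim=-1)   # [B, N, k]
    batch_indices = torch.arange(B).view(-1, 1, 1)                  # [B, 1, 1]
    Nb_h = e_t[batch_indices, topk_index, :]                        # [B, N, k, C] gathered neighbours
    return Nb_h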
if __name__ == "__main__":
    data = torch.randn((1, 10000, 384)).cuda()
    model = WiKG(dim_in=384, dim_hidden=512, topk=6, n_classes=2, agg_type='bi-interaction', dropout=0.3, pool='attn').cuda()
    output = model(data)
    print(output.shape)  # expected: torch.Size([1, 2])
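
    # Illustrative extra check (not in the original file): forward() also accepts a dict
    # that stores the patch features under the "feature" key.
    bag = {"feature": data}
    print(model(bag).shape)  # same result: torch.Size([1, 2])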