Problem in class IndexSelect #4
Description
Hi, your GXN work is interesting and I'm trying to implement it with DGL. However, I found something that looks strange in your code.
Specifically, in your paper the criterion for VIPool involves two function values, T_w(\mathbf{x}_v, \mathbf{y}_{\mathcal{N}_v}) and T_w(\mathbf{x}_v, \mathbf{y}_{\mathcal{N}_u}), where T_w(\mathbf{x}_v, \mathbf{y}_{\mathcal{N}_u}) = \mathcal{S}_w(\mathcal{E}_w(\mathbf{x}_v), \mathcal{P}_w(\mathbf{y}_{\mathcal{N}_u})). As described in the paper, \mathcal{E}_w is an MLP and \mathcal{P}_w is a message-passing layer.
However, your code seems to implement these in the form T_w(\mathbf{y}_{\mathcal{N}_v}, \mathbf{x}_v) and T_w(\mathbf{y}_{\mathcal{N}_u}, \mathbf{x}_v). More specifically, in the class IndexSelect of your code:
```python
class IndexSelect(nn.Module):
    def __init__(self, k, n_h, act, R=1):
        super().__init__()
        self.k = k
        self.R = R
        self.sigm = nn.Sigmoid()
        self.fc = MLP(n_h, n_h, act)
        self.disc = Discriminator(n_h)
        self.gcn1 = GCN(n_h, n_h)

    def forward(self, seq1, seq2, A, samp_bias1=None, samp_bias2=None):
        h_1 = self.fc(seq1)       # MLP output for seq1
        h_2 = self.fc(seq2)       # MLP output for seq2
        h_n1 = self.gcn1(A, h_1)  # GCN applied on top of the MLP output
        X = self.sigm(h_n1)
        ret, ret_true = self.disc(X, h_1, h_2, samp_bias1, samp_bias2)
        scores = self.sigm(ret_true).squeeze()

        num_nodes = A.shape[0]
        values, idx = torch.topk(scores, int(num_nodes))
        values1, idx1 = values[:int(self.k*num_nodes)], idx[:int(self.k*num_nodes)]
        values0, idx0 = values[int(self.k*num_nodes):], idx[int(self.k*num_nodes):]
        return ret, values1, idx1, idx0, h_n1
```
It looks like only X is the output of the GCN, while h_1 and h_2 (which, in my understanding, represent \mathcal{P}_w(\mathbf{y}_{\mathcal{N}_v}) and \mathcal{P}_w(\mathbf{y}_{\mathcal{N}_u}) respectively) are outputs of the MLP. If we follow the setting in your paper, shouldn't h_1 and h_2 be the outputs of the GCN?
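To make the question concrete, here is a minimal sketch of the forward pass as I would have expected it from the paper's formulation, reusing the same MLP / GCN / Discriminator modules already defined in IndexSelect. This is only an illustration of my reading of the paper, not a claim about your intended implementation; in particular, the argument order expected by self.disc is assumed to stay the same.

```python
def forward(self, seq1, seq2, A, samp_bias1=None, samp_bias2=None):
    # E_w(x_v): the MLP encodes each node's own feature
    h_x = self.fc(seq1)

    # P_w(y_{N_v}) and P_w(y_{N_u}): message passing aggregates the
    # positive / negative neighbourhood features
    h_1 = self.gcn1(A, seq1)
    h_2 = self.gcn1(A, seq2)

    X = self.sigm(h_x)

    # S_w: the discriminator compares E_w(x_v) with P_w(y_{N_v}) and P_w(y_{N_u})
    ret, ret_true = self.disc(X, h_1, h_2, samp_bias1, samp_bias2)
    scores = self.sigm(ret_true).squeeze()
    # ... top-k node selection as in the original forward
```

Is my understanding correct, or is there a reason the MLP and GCN are arranged the way they are in the released code?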