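Hello! I have recently been trying to train the model on a genomic dataset I built myself, so I need to convert the dataset into a format the model can accept for training. I have two questions:
1. Below is your original code; I only changed the class names. What is `features_list` in `forward()`? Is it the list obtained from the graph's per-node-type dictionary, with length `num_of_ntypes`, where each element is the tensor of embeddings for all nodes of that type?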
import torch
import torch.nn as nn


class NodeClassificationHGAT(nn.Module):
    def __init__(self,
                 g,
                 edge_dim,
                 num_etypes,
                 in_dims,
                 num_hidden,
                 num_classes,
                 num_layers,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual,
                 alpha):
        super(NodeClassificationHGAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # one linear projection per node type, mapping each input dim to num_hidden
        self.fc_list = nn.ModuleList([nn.Linear(in_dim, num_hidden, bias=True) for in_dim in in_dims])
        for fc in self.fc_list:
            nn.init.xavier_normal_(fc.weight, gain=1.414)
        # input projection (no residual)
        self.gat_layers.append(
            NodeClassificationHGATConv(
                edge_dim,        # edge_feats
                num_etypes,      # num_etypes
                num_hidden,      # in_feats
                num_hidden,      # out_feats
                heads[0],        # num_heads
                feat_drop,       # feat_drop
                attn_drop,       # attn_drop
                negative_slope,  # leaky_relu
                False,           # residual
                self.activation,
                True,            # allow_zero_in_degree
                True,            # bias
                alpha=alpha))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_feats = num_hidden * num_heads
            self.gat_layers.append(
                NodeClassificationHGATConv(
                    edge_dim,
                    num_etypes,
                    num_hidden * heads[l-1],
                    num_hidden,
                    heads[l],
                    feat_drop,
                    attn_drop,
                    negative_slope,
                    residual,
                    self.activation,
                    True,
                    True,
                    alpha=alpha))
        # output projection
        self.gat_layers.append(
            NodeClassificationHGATConv(
                edge_dim,
                num_etypes,
                num_hidden * heads[-2],
                num_classes,
                heads[-1],
                feat_drop,
                attn_drop,
                negative_slope,
                residual,
                None,
                True,
                alpha=alpha))
        self.epsilon = torch.FloatTensor([1e-12]).cuda()

    def forward(self, features_list, e_feat):
        h = []
        for fc, feature in zip(self.fc_list, features_list):
            h.append(fc(feature))
        h = torch.cat(h, 0)
        res_attn = None
        for l in range(self.num_layers):
            h, res_attn = self.gat_layers[l](self.g, h, e_feat, res_attn=res_attn)
            h = h.flatten(1)
        # output projection
        logits, _ = self.gat_layers[-1](self.g, h, e_feat, res_attn=None)
        logits = logits.mean(1)
        # This is an equivalent replacement for tf.l2_normalize, see https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/math/l2_normalize for more information.
        logits = logits / (torch.max(torch.norm(logits, dim=1, keepdim=True), self.epsilon))
        return logits
When I tried to use the node embeddings under `graph.ntypes` directly as the input `h`, I got the following error:
```
Traceback (most recent call last):
  File "/root/hhs/gene_pretrain/graph_pretrain/graph_model/SimpleHGAT/GraphNodeClassification.py", line 385, in <module>
    train_acc, val_acc, test_acc, precision, recall, f1 = train(graph, args)
                                                          ^^^^^^^^^^^^^^^^^^
  File "/root/hhs/gene_pretrain/graph_pretrain/graph_model/SimpleHGAT/GraphNodeClassification.py", line 206, in train
    train_pred_labels = model(train_features_list, e_feat_tensor)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/anaconda3/envs/ocean/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/hhs/gene_pretrain/graph_pretrain/graph_model/SimpleHGAT/HGAT.py", line 94, in forward
    h, res_attn = self.gat_layers[l](self.g, h, e_feat, res_attn=res_attn)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/anaconda3/envs/ocean/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/hhs/gene_pretrain/graph_pretrain/graph_model/SimpleHGAT/conv.py", line 117, in forward
    graph.srcdata.update({'ft': feat_src, 'el': el})
  File "<frozen _collections_abc>", line 949, in update
  File "/data/anaconda3/envs/ocean/lib/python3.11/site-packages/dgl/view.py", line 86, in __setitem__
    assert isinstance(val, dict), (
AssertionError: Current HeteroNodeDataView has multiple node types, please passing the node type and the corresponding data through a dict.
```
Is this expected? Is there a good, general way to solve it?
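One possible reading of the assertion, offered as an assumption rather than a confirmed diagnosis: `forward()` concatenates the projected features of every node type into a single tensor `h`, so the layer appears to expect `self.g` to have a single node type whose node IDs follow the same type-by-type order as `features_list`. The sketch below uses plain Python and hypothetical names (`num_nodes_per_type`, `typed_edges`, `type_offsets`, `to_global_edges`, and the node types `gene` and `disease`) to show the offset bookkeeping such a conversion needs. DGL's `dgl.to_homogeneous` performs a similar conversion, but its node ordering should be checked against the order of `features_list` before relying on it.

```python
from itertools import accumulate

# Hypothetical toy heterograph: node counts per type, in the same order as features_list.
num_nodes_per_type = {"gene": 3, "disease": 2}

# Hypothetical typed edges: (src_type, dst_type) -> list of (local_src_id, local_dst_id).
typed_edges = {
    ("gene", "disease"): [(0, 0), (2, 1)],
    ("disease", "gene"): [(1, 0)],
}


def type_offsets(counts):
    """Return the first global row index of each node type after concatenation."""
    starts = [0] + list(accumulate(counts.values()))[:-1]
    return dict(zip(counts, starts))


def to_global_edges(edges, offsets):
    """Map per-type local node IDs to rows of the concatenated feature matrix."""
    global_edges = []
    for etype_id, ((src_type, dst_type), pairs) in enumerate(edges.items()):
        for src, dst in pairs:
            global_edges.append((offsets[src_type] + src, offsets[dst_type] + dst, etype_id))
    return global_edges


offsets = type_offsets(num_nodes_per_type)
global_edges = to_global_edges(typed_edges, offsets)
print(offsets)
print(global_edges)
```

Each resulting triple is `(global_src, global_dst, edge_type_id)`. The first two entries index rows of the concatenated `h`, and the third is the kind of per-edge type ID that `e_feat` presumably carries, given that the model takes `num_etypes` as a parameter.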
2. In the original code of the HGAT layer that you adapted from the GAT layer,
```
if isinstance(feat, tuple):
    h_src = self.feat_drop(feat[0])
    h_dst = self.feat_drop(feat[1])
    if not hasattr(self, 'fc_src'):
        self.fc_src, self.fc_dst = self.fc, self.fc
    feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
    feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
else:
    h_src = h_dst = self.feat_drop(feat)
    feat_src = feat_dst = self.fc(h_src).view(
        -1, self._num_heads, self._out_feats)
    if graph.is_block:
        feat_dst = feat_src[:graph.number_of_dst_nodes()]
```
In this part, under what circumstances is the check `if isinstance(feat, tuple):` triggered?
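The reason for my confusion is that in
```
def forward(self, features_list, e_feat):
    h = []
    for fc, feature in zip(self.fc_list, features_list):
        h.append(fc(feature))
    h = torch.cat(h, 0)
```
the result `h` must be a single tensor, right? So in the code quoted above, only the branch after `else` can ever be reached?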
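For context, DGL's stock `GATConv` accepts `feat` either as a single tensor or as a `(source, destination)` pair; the pair form is intended for bipartite graphs, where source and destination nodes carry separate feature matrices. The sketch below mirrors only the type check, with nested lists standing in for tensors and a hypothetical helper `which_branch`, to illustrate that stacking the per-type projections yields one object rather than a tuple. It is an illustration under those assumptions, not the library's code.

```python
def which_branch(feat):
    """Mirror the isinstance check at the top of the conv layer's forward()."""
    if isinstance(feat, tuple):
        src_feat, dst_feat = feat  # separate source / destination features
        return "tuple", len(src_feat), len(dst_feat)
    # a single feature matrix is shared by source and destination nodes
    return "single", len(feat), len(feat)


# Stand-ins for the per-type projections fc(feature): 3 and 2 nodes, 4 hidden dims.
projected = [
    [[0.0] * 4 for _ in range(3)],
    [[0.0] * 4 for _ in range(2)],
]

# Stand-in for torch.cat(h, 0): stack all rows into one matrix.
h = [row for block in projected for row in block]

single_result = which_branch(h)
pair_result = which_branch((projected[0], projected[1]))
print(single_result)
print(pair_result)
```

Under this reading, the tuple branch would only be reached if the caller passed a pair explicitly, which the `forward()` quoted above never does.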
Thank you for your work!