diff --git a/model.py b/model.py index 2fd3377..48497a7 100644 --- a/model.py +++ b/model.py @@ -34,7 +34,7 @@ def __init__(self, feature_size, model_params): self.bn1 = BatchNorm1d(embedding_size) # Other layers - for i in range(self.n_layers): + for i in range(len(self.n_layers)): self.conv_layers.append(TransformerConv(embedding_size, embedding_size, heads=n_heads, @@ -62,7 +62,7 @@ def forward(self, x, edge_attr, edge_index, batch_index): # Holds the intermediate graph representations global_representation = [] - for i in range(self.n_layers): + for i in range(len(self.n_layers)): x = self.conv_layers[i](x, edge_index, edge_attr) x = torch.relu(self.transf_layers[i](x)) x = self.bn_layers[i](x)