diff --git a/model.py b/model.py index d222e54..f4bcb1f 100644 --- a/model.py +++ b/model.py @@ -103,8 +103,8 @@ def __init__(self, state_dim, annotation_dim, n_edge_types, n_nodes, m.weight.data.normal_(0.0, 0.02) m.bias.data.fill_(0) - self.in_fcs = AttrProxy(self, 'in_') - self.out_fcs = AttrProxy(self, 'out_') + self.in_fcs = ListModule(self, 'in_') + self.out_fcs = ListModule(self, 'out_') def forward(self, init_hidden_state, annotation, adj_matrix):