@@ -237,3 +237,80 @@ def forward(self, batch):
237237 )
238238
239239 return self .lin_sequential (graph_vector )
240+
241+
242+ class FGNodePoolingNoGraphNodeNet (GraphNetWrapper , ABC ):
243+ """Graph Node not considered here in any computation"""
244+
245+ def _get_lin_seq_input_dim (self , gnn_out_dim , n_molecule_properties ):
246+ # atom_embeddings + molecule attributes + functional_group_node_embeddings
247+ return gnn_out_dim + n_molecule_properties + gnn_out_dim
248+
249+ def forward (self , batch ):
250+ graph_data = batch ["features" ][0 ]
251+ assert isinstance (graph_data , GraphData )
252+ is_graph_node = graph_data .is_graph_node .bool ()
253+ is_atom_node = graph_data .is_atom_node .bool ()
254+ is_fg_node = (~ is_atom_node ) & (~ is_graph_node )
255+
256+ node_embeddings = self .gnn (batch )
257+
258+ atom_embeddings = node_embeddings [is_atom_node ]
259+ atom_batch = graph_data .batch [is_atom_node ]
260+
261+ fg_node_embeddings = node_embeddings [is_fg_node ]
262+ fg_node_batch = graph_data .batch [is_fg_node ]
263+
264+ # Scatter add separately
265+ atom_vec = scatter_add (atom_embeddings , atom_batch , dim = 0 )
266+ fg_node_vec = scatter_add (fg_node_embeddings , fg_node_batch , dim = 0 )
267+
268+ # Concatenate all
269+ graph_vector = torch .cat (
270+ [
271+ atom_vec ,
272+ graph_data .molecule_attr ,
273+ fg_node_vec ,
274+ ],
275+ dim = 1 ,
276+ )
277+
278+ return self .lin_sequential (graph_vector )
279+
280+
281+ class GraphNodeNoFGNodePoolingNet (GraphNetWrapper , ABC ):
282+ """Functional Group Nodes not considered here in any computation"""
283+
284+ def _get_lin_seq_input_dim (self , gnn_out_dim , n_molecule_properties ):
285+ # atom_embeddings + molecule attributes + graph_node_embeddings
286+ return gnn_out_dim + n_molecule_properties + gnn_out_dim
287+
288+ def forward (self , batch ):
289+ graph_data = batch ["features" ][0 ]
290+ assert isinstance (graph_data , GraphData )
291+ is_graph_node = graph_data .is_graph_node .bool ()
292+ is_atom_node = graph_data .is_atom_node .bool ()
293+
294+ node_embeddings = self .gnn (batch )
295+
296+ graph_node_embedding = node_embeddings [is_graph_node ]
297+ graph_node_batch = graph_data .batch [is_graph_node ]
298+
299+ atom_embeddings = node_embeddings [is_atom_node ]
300+ atom_batch = graph_data .batch [is_atom_node ]
301+
302+ # Scatter add separately
303+ graph_node_vec = scatter_add (graph_node_embedding , graph_node_batch , dim = 0 )
304+ atom_vec = scatter_add (atom_embeddings , atom_batch , dim = 0 )
305+
306+ # Concatenate all
307+ graph_vector = torch .cat (
308+ [
309+ atom_vec ,
310+ graph_data .molecule_attr ,
311+ graph_node_vec ,
312+ ],
313+ dim = 1 ,
314+ )
315+
316+ return self .lin_sequential (graph_vector )
0 commit comments