Skip to content

Commit f52d6e3

Browse files
committed
more class for ablation studies
1 parent 7efbea5 commit f52d6e3

File tree

3 files changed

+99
-0
lines changed

3 files changed

+99
-0
lines changed

chebai_graph/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from .augmented import (
22
ResGatedAugNodePoolGraphPred,
3+
ResGatedFGNodeNoGraphNodeGraphPred,
34
ResGatedFGNodePoolGraphPred,
45
ResGatedGraphNodeFGNodePoolGraphPred,
6+
ResGatedGraphNodeNoFGNodeGraphPred,
57
ResGatedGraphNodePoolGraphPred,
68
)
79
from .gat import GATGraphPred
@@ -10,8 +12,10 @@
1012
__all__ = [
1113
"GATGraphPred",
1214
"ResGatedGraphPred",
15+
"ResGatedFGNodeNoGraphNodeGraphPred",
1316
"ResGatedAugNodePoolGraphPred",
1417
"ResGatedGraphNodeFGNodePoolGraphPred",
1518
"ResGatedGraphNodePoolGraphPred",
19+
"ResGatedGraphNodeNoFGNodeGraphPred",
1620
"ResGatedFGNodePoolGraphPred",
1721
]

chebai_graph/models/augmented.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from .base import (
22
AugmentedNodePoolingNet,
33
FGNodePoolingNet,
4+
FGNodePoolingNoGraphNodeNet,
45
GraphNodeFGNodePoolingNet,
6+
GraphNodeNoFGNodePoolingNet,
57
GraphNodePoolingNet,
68
)
79
from .resgated import ResGatedGraphPred
@@ -31,3 +33,19 @@ class ResGatedGraphNodeFGNodePoolGraphPred(
3133
"""GNN for graph-level prediction for augmented graphs"""
3234

3335
NAME = "ResGatedGraphNodeFGNodePoolGraphPred"
36+
37+
38+
class ResGatedGraphNodeNoFGNodeGraphPred(
39+
GraphNodeNoFGNodePoolingNet, ResGatedGraphPred
40+
):
41+
"""GNN for graph-level prediction for augmented graphs without FG nodes"""
42+
43+
NAME = "ResGatedGraphNodeNoFGNodeGraphPred"
44+
45+
46+
class ResGatedFGNodeNoGraphNodeGraphPred(
47+
FGNodePoolingNoGraphNodeNet, ResGatedGraphPred
48+
):
49+
"""GNN for graph-level prediction for augmented graphs without FG nodes"""
50+
51+
NAME = "ResGatedFGNodeNoGraphNodeGraphPred"

chebai_graph/models/base.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,80 @@ def forward(self, batch):
237237
)
238238

239239
return self.lin_sequential(graph_vector)
240+
241+
242+
class FGNodePoolingNoGraphNodeNet(GraphNetWrapper, ABC):
243+
"""Graph Node not considered here in any computation"""
244+
245+
def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties):
246+
# atom_embeddings + molecule attributes + functional_group_node_embeddings
247+
return gnn_out_dim + n_molecule_properties + gnn_out_dim
248+
249+
def forward(self, batch):
250+
graph_data = batch["features"][0]
251+
assert isinstance(graph_data, GraphData)
252+
is_graph_node = graph_data.is_graph_node.bool()
253+
is_atom_node = graph_data.is_atom_node.bool()
254+
is_fg_node = (~is_atom_node) & (~is_graph_node)
255+
256+
node_embeddings = self.gnn(batch)
257+
258+
atom_embeddings = node_embeddings[is_atom_node]
259+
atom_batch = graph_data.batch[is_atom_node]
260+
261+
fg_node_embeddings = node_embeddings[is_fg_node]
262+
fg_node_batch = graph_data.batch[is_fg_node]
263+
264+
# Scatter add separately
265+
atom_vec = scatter_add(atom_embeddings, atom_batch, dim=0)
266+
fg_node_vec = scatter_add(fg_node_embeddings, fg_node_batch, dim=0)
267+
268+
# Concatenate all
269+
graph_vector = torch.cat(
270+
[
271+
atom_vec,
272+
graph_data.molecule_attr,
273+
fg_node_vec,
274+
],
275+
dim=1,
276+
)
277+
278+
return self.lin_sequential(graph_vector)
279+
280+
281+
class GraphNodeNoFGNodePoolingNet(GraphNetWrapper, ABC):
282+
"""Functional Group Nodes not considered here in any computation"""
283+
284+
def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties):
285+
# atom_embeddings + molecule attributes + graph_node_embeddings
286+
return gnn_out_dim + n_molecule_properties + gnn_out_dim
287+
288+
def forward(self, batch):
289+
graph_data = batch["features"][0]
290+
assert isinstance(graph_data, GraphData)
291+
is_graph_node = graph_data.is_graph_node.bool()
292+
is_atom_node = graph_data.is_atom_node.bool()
293+
294+
node_embeddings = self.gnn(batch)
295+
296+
graph_node_embedding = node_embeddings[is_graph_node]
297+
graph_node_batch = graph_data.batch[is_graph_node]
298+
299+
atom_embeddings = node_embeddings[is_atom_node]
300+
atom_batch = graph_data.batch[is_atom_node]
301+
302+
# Scatter add separately
303+
graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0)
304+
atom_vec = scatter_add(atom_embeddings, atom_batch, dim=0)
305+
306+
# Concatenate all
307+
graph_vector = torch.cat(
308+
[
309+
atom_vec,
310+
graph_data.molecule_attr,
311+
graph_node_vec,
312+
],
313+
dim=1,
314+
)
315+
316+
return self.lin_sequential(graph_vector)

0 commit comments

Comments
 (0)