Skip to content

Commit 751f63e

Browse files
committed
separate pooling for atom and augmented nodes
One vector: average of atom embeddings. One vector: average of augmented node embeddings. #2 (comment)
1 parent 243722b commit 751f63e

File tree

5 files changed

+105
-3
lines changed

5 files changed

+105
-3
lines changed

chebai_graph/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from ._gat import GATModelWrapper
2+
from .graph import ResGatedAugmentedGraphPred
23

3-
__all__ = ["GATModelWrapper"]
4+
__all__ = ["GATModelWrapper", "ResGatedAugmentedGraphPred"]

chebai_graph/models/graph.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,68 @@ def forward(self, batch):
188188
return a
189189

190190

191+
class ResGatedAugmentedGraphPred(GraphBaseNet):
    """GNN for graph-level prediction on augmented graphs.

    Node embeddings produced by the underlying GNN are pooled separately for
    atom nodes and for augmented (non-atom) nodes. The two pooled vectors are
    concatenated with the molecule-level attributes and passed through a small
    MLP that produces the graph-level output.
    """

    NAME = "ResGatedAugmentedGraphPred"

    def __init__(
        self,
        config: typing.Dict,
        n_linear_layers=2,
        **kwargs,
    ):
        """Build the GNN encoder and the prediction head.

        Args:
            config: Configuration forwarded to ``ResGatedGraphConvNetBase``.
            n_linear_layers: Total number of linear layers in the head,
                counting the final output layer. Must be at least 2, because
                the final layer expects ``hidden_length`` inputs and the
                pooled vector is wider than that.
            **kwargs: Forwarded to the base class and to the GNN encoder.

        Raises:
            ValueError: If ``n_linear_layers`` is smaller than 2.
        """
        if n_linear_layers < 2:
            # Without at least one hidden layer nothing projects the pooled
            # vector down to `hidden_length` for `final_layer`.
            raise ValueError(
                f"n_linear_layers must be at least 2, got {n_linear_layers}"
            )
        super().__init__(**kwargs)
        self.gnn = ResGatedGraphConvNetBase(config, **kwargs)

        hidden_length = self.gnn.hidden_length
        # Width of the pooled vector: atom pooling + augmented-node pooling
        # (one `hidden_length` each) plus the molecule-level properties.
        pooled_length = 2 * hidden_length + self.gnn.n_molecule_properties
        self.linear_layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(
                    pooled_length if i == 0 else hidden_length,
                    hidden_length,
                )
                for i in range(n_linear_layers - 1)
            ]
        )
        self.final_layer = nn.Linear(hidden_length, self.out_dim)

    def forward(self, batch):
        """Compute graph-level predictions for a batch of augmented graphs.

        Args:
            batch: Mapping whose ``"features"`` entry holds the ``GraphData``
                batch at index 0. The graph data must provide
                ``is_atom_node``, ``batch`` and ``molecule_attr``.

        Returns:
            Output of the final linear layer, one row per graph.

        Raises:
            ValueError: If the graph data carries no ``is_atom_node`` mask.
        """
        graph_data = batch["features"][0]
        assert isinstance(graph_data, GraphData)

        atom_mask = getattr(graph_data, "is_atom_node", None)
        if atom_mask is None:
            # The mask can be absent when the data was not produced by an
            # augmented-graph reader; fail with an explicit message.
            raise ValueError(
                "graph data has no `is_atom_node` mask; "
                "ResGatedAugmentedGraphPred requires augmented graph data"
            )
        is_atom_node = atom_mask.bool()  # Boolean mask over nodes
        is_augmented_node = ~is_atom_node

        node_embeddings = self.gnn(batch)

        # `molecule_attr` is concatenated row-wise with the pooled vectors
        # below, so its first dimension is taken as the number of graphs.
        # Passing it as `dim_size` keeps one row per graph even if a graph
        # has no nodes of one kind (otherwise trailing graphs would be cut).
        num_graphs = graph_data.molecule_attr.shape[0]

        # Sum-pool atom nodes and augmented nodes separately
        graph_vec_atoms = scatter_add(
            node_embeddings[is_atom_node],
            graph_data.batch[is_atom_node],
            dim=0,
            dim_size=num_graphs,
        )
        graph_vec_augmented_nodes = scatter_add(
            node_embeddings[is_augmented_node],
            graph_data.batch[is_augmented_node],
            dim=0,
            dim_size=num_graphs,
        )

        # Concatenate both pooled vectors with the molecule-level attributes
        graph_vector = torch.cat(
            [
                graph_vec_atoms,
                graph_data.molecule_attr,
                graph_vec_augmented_nodes,
            ],
            dim=1,
        )

        # Chain the hidden layers: each layer consumes the previous output
        a = graph_vector
        for lin in self.linear_layers:
            a = self.gnn.activation(lin(a))
        return self.final_layer(a)
251+
252+
191253
class ResGatedGraphConvNetPretrain(GraphBaseNet):
192254
"""For pretraining. BaseNet with an additional output layer for predicting atom properties"""
193255

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,16 @@ def _merge_props_into_base(self, row):
178178
)
179179
else:
180180
molecule_attr = torch.cat([molecule_attr, property_values], dim=1)
181+
182+
is_atom_node = (
183+
geom_data.is_atom_node if hasattr(geom_data, "is_atom_node") else None
184+
)
181185
return GeomData(
182186
x=x,
183187
edge_index=geom_data.edge_index,
184188
edge_attr=edge_attr,
185189
molecule_attr=molecule_attr,
190+
is_atom_node=is_atom_node,
186191
)
187192

188193
def load_processed_data_from_file(self, filename):

chebai_graph/preprocessing/reader/augmented_reader.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from abc import ABC, abstractmethod
23
from typing import Dict, List, Optional, Tuple
34

@@ -179,7 +180,12 @@ def _read_data(self, smiles: str) -> GeomData | None:
179180
self.mol_object_buffer[smiles] = augmented_molecule
180181

181182
# Empty features initialized; node and edge features can be added later
182-
x = torch.zeros((augmented_molecule["nodes"]["num_nodes"], 0))
183+
NUM_NODES = augmented_molecule["nodes"]["num_nodes"]
184+
assert (
185+
NUM_NODES is not None and NUM_NODES > 1
186+
), "Num of nodes in augmented graph should be more than 1"
187+
188+
x = torch.zeros((NUM_NODES, 0))
183189
edge_attr = torch.zeros((augmented_molecule["edges"][k.NUM_EDGES], 0))
184190

185191
assert (
@@ -194,7 +200,14 @@ def _read_data(self, smiles: str) -> GeomData | None:
194200
len(set(edge_index[0].tolist())) == x.shape[0]
195201
), f"Number of unique source nodes in edge_index ({len(set(edge_index[0].tolist()))}) does not match number of nodes in x ({x.shape[0]})"
196202

197-
return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr)
203+
# Create a boolean mask: True for atom, False for augmented
204+
is_atom_mask = torch.zeros(NUM_NODES, dtype=torch.bool)
205+
NUM_ATOM_NODES = augmented_molecule["nodes"]["atom_nodes"].GetNumAtoms()
206+
is_atom_mask[:NUM_ATOM_NODES] = True
207+
208+
return GeomData(
209+
x=x, edge_index=edge_index, edge_attr=edge_attr, is_atom_node=is_atom_mask
210+
)
198211

199212
def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, dict]:
200213
"""
@@ -267,6 +280,14 @@ def _augment_graph_structure(
267280
assert (
268281
self._num_of_nodes == total_atoms
269282
), f"Mismatch in number of nodes: expected {total_atoms}, got {self._num_of_nodes}"
283+
assert sys.version_info >= (
284+
3,
285+
7,
286+
), "This code requires Python 3.7 or higher."
287+
# From Python 3.7 on, the standard dict type preserves insertion order and is iterated over in that same order
288+
# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
289+
# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
290+
# Order preservation is necessary to create the `is_atom_node` mask
270291
node_info = {
271292
"atom_nodes": mol,
272293
"fg_nodes": fg_nodes,

configs/model/gnn_resgated_aug.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
class_path: chebai_graph.models.ResGatedAugmentedGraphPred
2+
init_args:
3+
optimizer_kwargs:
4+
lr: 1e-3
5+
config:
6+
in_length: 256
7+
hidden_length: 512
8+
dropout_rate: 0.1
9+
n_conv_layers: 3
10+
n_linear_layers: 3
11+
n_atom_properties: 158
12+
n_bond_properties: 7
13+
n_molecule_properties: 200

0 commit comments

Comments
 (0)