
Commit 79fc500

dynamic gni

1 parent a0d6ea7 commit 79fc500

File tree

6 files changed (+125, -23 lines)


chebai_graph/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
     ResGatedAugNodePoolGraphPred,
     ResGatedGraphNodeFGNodePoolGraphPred,
 )
+from .dynamic_gni import ResGatedDynamicGNIGraphPred
 from .gat import GATGraphPred
 from .resgated import ResGatedGraphPred

@@ -14,4 +15,5 @@
     "GATGraphPred",
     "GATAugNodePoolGraphPred",
     "GATGraphNodeFGNodePoolGraphPred",
+    "ResGatedDynamicGNIGraphPred",
 ]

chebai_graph/models/dynamic_gni.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+from typing import Any
+
+import torch
+from torch import Tensor
+from torch.nn import ELU
+from torch_geometric.data import Data as GraphData
+from torch_geometric.nn.models.basic_gnn import BasicGNN
+
+from chebai_graph.preprocessing.reader import RandomFeatureInitializationReader
+
+from .base import GraphModelBase, GraphNetWrapper
+from .resgated import ResGatedModel
+
+
+class ResGatedDynamicGNI(GraphModelBase):
+    """
+    Base model class for applying ResGatedGraphConv layers to graph-structured data
+    with dynamic initialization of features for nodes and edges.
+
+    Args:
+        config (dict): Configuration dictionary containing model hyperparameters.
+        **kwargs: Additional keyword arguments for the parent class.
+    """
+
+    def __init__(self, config: dict[str, Any], **kwargs: Any):
+        super().__init__(config=config, **kwargs)
+        self.activation = ELU()  # Instantiate ELU once for reuse.
+        distribution = config.get("distribution", "normal")
+        assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"]
+        self.distribution = distribution
+
+        self.resgated: BasicGNN = ResGatedModel(
+            in_channels=self.in_channels,
+            hidden_channels=self.hidden_channels,
+            out_channels=self.out_channels,
+            num_layers=self.num_layers,
+            edge_dim=self.edge_dim,
+            act=self.activation,
+        )
+
+    def forward(self, batch: dict[str, Any]) -> Tensor:
+        """
+        Forward pass of the model.
+
+        Args:
+            batch (dict): A batch containing graph input features under the key "features".
+
+        Returns:
+            Tensor: The output node-level embeddings after the final activation.
+        """
+        graph_data = batch["features"][0]
+        assert isinstance(graph_data, GraphData), "Expected GraphData instance"
+
+        random_x = torch.empty(graph_data.x.shape[0], graph_data.x.shape[1])
+        RandomFeatureInitializationReader.random_gni(random_x, self.distribution)
+        random_edge_attr = torch.empty(
+            graph_data.edge_attr.shape[0], graph_data.edge_attr.shape[1]
+        )
+        RandomFeatureInitializationReader.random_gni(
+            random_edge_attr, self.distribution
+        )
+
+        out = self.resgated(
+            x=random_x,  # feed the freshly sampled node features
+            edge_index=graph_data.edge_index.long(),
+            edge_attr=random_edge_attr,  # and the freshly sampled edge features
+        )
+
+        return self.activation(out)
+
+
+class ResGatedDynamicGNIGraphPred(GraphNetWrapper):
+    """
+    Wrapper for graph-level prediction using ResGatedDynamicGNI.
+
+    This class instantiates the core GNN model using the provided config.
+    """
+
+    def _get_gnn(self, config: dict[str, Any]) -> ResGatedDynamicGNI:
+        """
+        Returns the core ResGated GNN model.
+
+        Args:
+            config (dict): Configuration dictionary for the GNN model.
+
+        Returns:
+            ResGatedDynamicGNI: The core graph convolutional network.
+        """
+        return ResGatedDynamicGNI(config=config)
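
For orientation, a minimal usage sketch of the new wrapper, assuming GraphNetWrapper forwards these keyword arguments unchanged. The values mirror the model config added at the end of this commit; the optional "distribution" key defaults to "normal" in ResGatedDynamicGNI.

from chebai_graph.models import ResGatedDynamicGNIGraphPred

# Direct instantiation for illustration; in practice the class is usually
# built from the YAML model config shown at the end of this commit.
model = ResGatedDynamicGNIGraphPred(
    optimizer_kwargs={"lr": 1e-3},
    config={
        "in_channels": 158,        # number of node/atom properties
        "hidden_channels": 256,
        "out_channels": 512,
        "num_layers": 4,
        "edge_dim": 7,             # number of bond properties
        "dropout": 0,
        "n_molecule_properties": 0,
        "n_linear_layers": 1,
        "distribution": "normal",  # optional; one of the four supported values
    },
)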

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@
     GN_WithAtoms_FG_WithAtoms_NoFGE,
     GraphPropertyReader,
     GraphReader,
-    RandomNodeInitializationReader,
+    RandomFeatureInitializationReader,
 )

 from .utils import resolve_property
@@ -518,7 +518,7 @@ def _merge_props_into_base(


 class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50):
-    READER = RandomNodeInitializationReader
+    READER = RandomFeatureInitializationReader

     def _setup_properties(self): ...

chebai_graph/preprocessing/reader/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
     GN_WithAtoms_FG_WithAtoms_NoFGE,
 )
 from .reader import GraphPropertyReader, GraphReader
-from .static_gni import RandomNodeInitializationReader
+from .static_gni import RandomFeatureInitializationReader

 __all__ = [
     "GraphReader",
@@ -20,7 +20,7 @@
     "AtomFGReader_NoFGEdges_WithGraphNode",
     "AtomFGReader_WithFGEdges_NoGraphNode",
     "AtomFGReader_WithFGEdges_WithGraphNode",
-    "RandomNodeInitializationReader",
+    "RandomFeatureInitializationReader",
     "GN_WithAtoms_FG_WithAtoms_FGE",
     "GN_WithAtoms_FG_WithAtoms_NoFGE",
     "GN_WithAllNodes_FG_WithAtoms_FGE",

chebai_graph/preprocessing/reader/static_gni.py

Lines changed: 17 additions & 19 deletions
@@ -12,7 +12,7 @@
 from .reader import GraphPropertyReader


-class RandomNodeInitializationReader(GraphPropertyReader):
+class RandomFeatureInitializationReader(GraphPropertyReader):
     def __init__(
         self,
         num_node_properties: int,
@@ -46,24 +46,9 @@ def _read_data(self, raw_data):
         )
         random_molecule_properties = torch.empty(1, self.num_molecule_properties)

-        if self.distribution == "normal":
-            torch.nn.init.normal_(random_x)
-            torch.nn.init.normal_(random_edge_attr)
-            torch.nn.init.normal_(random_molecule_properties)
-        elif self.distribution == "uniform":
-            torch.nn.init.uniform_(random_x, a=-1.0, b=1.0)
-            torch.nn.init.uniform_(random_edge_attr, a=-1.0, b=1.0)
-            torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0)
-        elif self.distribution == "xavier_normal":
-            torch.nn.init.xavier_normal_(random_x)
-            torch.nn.init.xavier_normal_(random_edge_attr)
-            torch.nn.init.xavier_normal_(random_molecule_properties)
-        elif self.distribution == "xavier_uniform":
-            torch.nn.init.xavier_uniform_(random_x)
-            torch.nn.init.xavier_uniform_(random_edge_attr)
-            torch.nn.init.xavier_uniform_(random_molecule_properties)
-        else:
-            raise ValueError("Unknown distribution type")
+        self.random_gni(random_x, self.distribution)
+        self.random_gni(random_edge_attr, self.distribution)
+        self.random_gni(random_molecule_properties, self.distribution)

         data.x = random_x
         data.edge_attr = random_edge_attr
@@ -73,3 +58,16 @@ def _read_data(self, raw_data):
     def read_property(self, *args, **kwargs) -> Exception:
         """This reader does not support reading specific properties."""
         raise NotImplementedError("This reader only performs random initialization.")
+
+    @staticmethod
+    def random_gni(tensor: torch.Tensor, distribution: str) -> None:
+        if distribution == "normal":
+            torch.nn.init.normal_(tensor)
+        elif distribution == "uniform":
+            torch.nn.init.uniform_(tensor, a=-1.0, b=1.0)
+        elif distribution == "xavier_normal":
+            torch.nn.init.xavier_normal_(tensor)
+        elif distribution == "xavier_uniform":
+            torch.nn.init.xavier_uniform_(tensor)
+        else:
+            raise ValueError("Unknown distribution type")
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+class_path: chebai_graph.models.ResGatedDynamicGNIGraphPred
+init_args:
+  optimizer_kwargs:
+    lr: 1e-3
+  config:
+    in_channels: 158 # number of node/atom properties
+    hidden_channels: 256
+    out_channels: 512
+    num_layers: 4
+    edge_dim: 7 # number of bond properties
+    dropout: 0
+    n_molecule_properties: 0
+    n_linear_layers: 1
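
For illustration only, one way such a config could be consumed outside the project's own training entry point: load the YAML and resolve class_path/init_args by hand. The file name below is hypothetical, and the real CLI may wire this up differently.

import importlib

import yaml

with open("resgated_dynamic_gni.yml") as f:  # hypothetical file name
    cfg = yaml.safe_load(f)

module_name, class_name = cfg["class_path"].rsplit(".", 1)
model_cls = getattr(importlib.import_module(module_name), class_name)
model = model_cls(**cfg["init_args"])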
