from typing import Any

import torch
from torch import Tensor
from torch.nn import ELU
from torch_geometric.data import Data as GraphData
from torch_geometric.nn.models.basic_gnn import BasicGNN

from chebai_graph.preprocessing.reader import RandomFeatureInitializationReader

from .base import GraphModelBase, GraphNetWrapper
from .resgated import ResGatedModel


class ResGatedDynamicGNI(GraphModelBase):
    """
    Model that applies ResGatedGraphConv layers to graph-structured data and
    dynamically re-initializes node and edge features with random values on
    every forward pass.

    Args:
        config (dict): Configuration dictionary containing model hyperparameters,
            including the optional key ``distribution`` (one of ``"normal"``,
            ``"uniform"``, ``"xavier_normal"``, ``"xavier_uniform"``).
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(self, config: dict[str, Any], **kwargs: Any):
        super().__init__(config=config, **kwargs)
        self.activation = ELU()  # Instantiate ELU once for reuse.
        distribution = config.get("distribution", "normal")
        assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"]
        self.distribution = distribution

        self.resgated: BasicGNN = ResGatedModel(
            in_channels=self.in_channels,
            hidden_channels=self.hidden_channels,
            out_channels=self.out_channels,
            num_layers=self.num_layers,
            edge_dim=self.edge_dim,
            act=self.activation,
        )

    def forward(self, batch: dict[str, Any]) -> Tensor:
        """
        Forward pass of the model.

        Node and edge features are re-initialized with random values drawn
        according to ``self.distribution`` before message passing, so the input
        features only determine the shapes of the feature tensors.

        Args:
            batch (dict): A batch containing graph input features under the key "features".

        Returns:
            Tensor: The output node-level embeddings after the final activation.
        """
        graph_data = batch["features"][0]
        assert isinstance(graph_data, GraphData), "Expected GraphData instance"

        # Draw fresh random node and edge features for this forward pass,
        # created on the same device as the input tensors.
        random_x = torch.empty(
            graph_data.x.shape[0], graph_data.x.shape[1], device=graph_data.x.device
        )
        RandomFeatureInitializationReader.random_gni(random_x, self.distribution)
        random_edge_attr = torch.empty(
            graph_data.edge_attr.shape[0],
            graph_data.edge_attr.shape[1],
            device=graph_data.edge_attr.device,
        )
        RandomFeatureInitializationReader.random_gni(
            random_edge_attr, self.distribution
        )

        # Pass the randomly initialized node and edge features to the GNN.
        out = self.resgated(
            x=random_x,
            edge_index=graph_data.edge_index.long(),
            edge_attr=random_edge_attr,
        )

        return self.activation(out)
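
    # Note: ``random_gni`` is assumed to fill the given tensor in place
    # according to the chosen distribution; a rough, hypothetical equivalent
    # using torch.nn.init would be:
    #
    #     init_fns = {
    #         "normal": torch.nn.init.normal_,
    #         "uniform": torch.nn.init.uniform_,
    #         "xavier_normal": torch.nn.init.xavier_normal_,
    #         "xavier_uniform": torch.nn.init.xavier_uniform_,
    #     }
    #     init_fns[distribution](tensor)
    #
    # The actual behaviour is defined in
    # chebai_graph.preprocessing.reader.RandomFeatureInitializationReader.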


class ResGatedDynamicGNIGraphPred(GraphNetWrapper):
    """
    Wrapper for graph-level prediction using ResGatedDynamicGNI.

    This class instantiates the core GNN model using the provided config.
    """

    def _get_gnn(self, config: dict[str, Any]) -> ResGatedDynamicGNI:
        """
        Returns the core ResGated GNN model.

        Args:
            config (dict): Configuration dictionary for the GNN model.

        Returns:
            ResGatedDynamicGNI: The core graph convolutional network.
        """
        return ResGatedDynamicGNI(config=config)
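

# Usage sketch (illustrative only): how the core model might be instantiated
# and called. The config keys below (``in_channels``, ``hidden_channels``,
# ``out_channels``, ``num_layers``, ``edge_dim``) are assumptions about what
# GraphModelBase reads from the config and may differ from the actual keys
# defined in ``.base``.
#
#     config = {
#         "in_channels": 16,
#         "hidden_channels": 64,
#         "out_channels": 32,
#         "num_layers": 3,
#         "edge_dim": 4,
#         "distribution": "xavier_uniform",
#     }
#     model = ResGatedDynamicGNI(config=config)
#     graph = GraphData(
#         x=torch.zeros(10, 16),
#         edge_index=torch.randint(0, 10, (2, 30)),
#         edge_attr=torch.zeros(30, 4),
#     )
#     node_embeddings = model({"features": [graph]})  # shape (10, 32)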