
Graph Neural Network Architectures #2

Merged
NavidCOMSC merged 71 commits into main from gnn on Sep 9, 2025
Conversation

@NavidCOMSC
Owner

The GNN models module is created to host the different GNN architectures that can be deployed for training in the RLModule class by defining the agent associated with a specific GNN model. The choice of architectures is based on their capabilities and how well they fit the attributes of machines and operations, with the edges between them acting as the message-passing interface.
There are no tests for this module yet; it will be tested after the agent has passed training on a simple feed-forward neural network.

@NavidCOMSC NavidCOMSC requested a review from Pabloo22 August 14, 2025 18:04
@NavidCOMSC NavidCOMSC self-assigned this Aug 14, 2025
@NavidCOMSC NavidCOMSC added the enhancement New feature or request label Aug 14, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @NavidCOMSC, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the project's machine learning capabilities by introducing a dedicated module for Graph Neural Network (GNN) architectures. It integrates four distinct GNN models: Graph Attention Network, Graph Transformer, Hierarchical GNN, and Multi-scale GNN. These additions are crucial for handling complex, graph-structured data, particularly in the context of job shop scheduling, by allowing for more sophisticated modeling of relationships between machines and operations. The new models are designed to be deployable within an existing RLModule class, laying the groundwork for advanced reinforcement learning applications on graph data.

Highlights

  • New GNN Models Module: A new gnn_models module has been introduced to house various Graph Neural Network architectures, providing a structured approach for integrating different GNNs into the system.
  • Graph Attention Network (GAT) Implementation: The pull request adds a GraphAttentionNetwork (GAT) implementation, which uses multi-head attention to process graph data, focusing on relevant operations and machines, and includes global pooling capabilities.
  • Graph Transformer Model: A GraphTransformer model has been added, which applies transformer-style attention to graph nodes, incorporating positional and structural encodings to capture complex relationships within the graph.
  • Hierarchical Graph Neural Network: A HierarchicalGNN is introduced, designed to process graph information at both local (operation-level) and global (job/machine-level) scales, combining these representations through cross-attention for a comprehensive understanding.
  • Multi-scale Graph Neural Network: The MultiScaleGNN is added, enabling the processing of graphs at different levels of coarseness. This model dynamically coarsens and upsamples graph representations, applying different GNN layers at each scale and aggregating the results.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces several new GNN architectures. The code is well-structured with docstrings and type hints. However, I've found several critical issues across the new models that would likely prevent them from training correctly or cause them to behave unexpectedly. These include incorrect use of random tensors, unused layers, bugs in data handling for batches, and potential division-by-zero errors. Most critically, the GraphTransformer does not currently use the graph structure in its attention mechanism, and the MultiScaleGNN has bugs in its graph coarsening logic. I've left detailed comments on these issues. Furthermore, the pull request description mentions the lack of tests. Given the complexity of these models, adding a comprehensive test suite is crucial to ensure their correctness and prevent future regressions. I strongly recommend adding unit tests for each model, covering both single graph and batched graph scenarios, before merging.

Comment on lines +135 to +138
struct_encoding = torch.matmul(
    adj_encoding,
    torch.randn(num_nodes, hidden_dim, device=edge_index.device),
)

critical

The structural encoding is created by multiplying the adjacency matrix with a newly generated random tensor (torch.randn(...)) on every forward pass. This is a critical issue as it introduces non-determinism and prevents the model from learning any meaningful representation from the graph structure. The tensor used for structural encoding should be a learnable parameter of the model.
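A minimal sketch of the fix the reviewer describes: the random tensor becomes a learnable parameter, so the structural encoding is deterministic and trainable. The class name `StructuralEncoding` and the `max_nodes` sizing convention are hypothetical illustrations, not code from this PR:

```python
import torch
import torch.nn as nn


class StructuralEncoding(nn.Module):
    """Learnable projection of an adjacency-based encoding (sketch)."""

    def __init__(self, max_nodes: int, hidden_dim: int):
        super().__init__()
        # Learnable parameter replaces the per-forward torch.randn(...);
        # it is registered with the module and updated by the optimizer.
        self.struct_embed = nn.Parameter(torch.randn(max_nodes, hidden_dim) * 0.02)

    def forward(self, adj_encoding: torch.Tensor) -> torch.Tensor:
        # adj_encoding: (num_nodes, num_nodes); slice the embedding to match.
        n = adj_encoding.size(0)
        return adj_encoding @ self.struct_embed[:n]
```

Unlike the original code, two forward passes over the same graph now produce identical encodings, and gradients flow into `struct_embed`.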

Comment on lines +142 to +149
def _create_attention_mask(
    self, edge_index: torch.Tensor, num_nodes: int, batch_size: int
) -> torch.Tensor | None:
    """Create attention mask to respect graph structure."""
    # For now this allows a fully connected network;
    # it can be modified to select specific edges or nodes.

    return None

critical

The _create_attention_mask method currently returns None, which causes the TransformerEncoder to use full attention over all nodes in the sequence. This completely ignores the graph's connectivity, meaning the model is not functioning as a Graph Transformer. To fix this, you need to create a proper attention mask from the edge_index that restricts attention to only neighboring nodes.
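One way to build such a mask, sketched for the single-graph case under the assumption that it is passed as `src_mask` to `nn.TransformerEncoder` (where `True` marks positions that must NOT be attended). The function name is hypothetical:

```python
import torch


def attention_mask_from_edges(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Boolean mask restricting attention to graph neighbours (sketch).

    True = attention blocked; False = attention allowed, matching the
    semantics of nn.TransformerEncoder's src_mask argument.
    """
    # Start fully blocked, then open edges and self-loops.
    mask = torch.ones(num_nodes, num_nodes, dtype=torch.bool)
    mask[edge_index[0], edge_index[1]] = False  # neighbours may attend
    mask.fill_diagonal_(False)                  # keep self-attention
    return mask
```

A batched version would additionally block attention between nodes of different graphs (e.g. by comparing `batch` vectors).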

machine_mask = node_types == 1
job_mask = node_types == 2

x_encoded = torch.zeros_like(x[:, : self.hidden_dim])

critical

The initialization of x_encoded using torch.zeros_like(x[:, : self.hidden_dim]) is incorrect and will likely lead to a runtime error. This code slices the input x up to self.hidden_dim, which will fail if x.shape[1] (which is node_dim) is smaller than hidden_dim. The tensor should be initialized with the correct target shape directly.

Suggested change
x_encoded = torch.zeros_like(x[:, : self.hidden_dim])
x_encoded = torch.zeros(x.size(0), self.hidden_dim, device=x.device)

num_nodes = x.size(0)

# Simple node clustering by grouping consecutive nodes
cluster_size = max(1, num_nodes // (num_nodes // coarsening_ratio))

critical

The calculation of cluster_size can lead to a ZeroDivisionError. If num_nodes is less than coarsening_ratio, the expression num_nodes // coarsening_ratio evaluates to 0, causing a division by zero. You should add a check to handle this edge case, for example by not coarsening graphs that are too small.

Suggested change
cluster_size = max(1, num_nodes // (num_nodes // coarsening_ratio))
num_clusters_approx = num_nodes // coarsening_ratio
if num_clusters_approx == 0:
    # Graph is too small to coarsen at this scale.
    # A single supernode will be created.
    cluster_size = num_nodes
else:
    cluster_size = max(1, num_nodes // num_clusters_approx)

coarse_x.append(pooled_features)
coarse_node_idx += 1

coarse_batch = None

critical

The line coarse_batch = None is incorrectly placed. It is outside the if/else block that handles batched vs. single graphs, causing coarse_batch to always be None after this block executes. This will break the logic for batched graphs in the _upsample_to_original method. This line should be moved inside the else block (e.g., after line 202) so it only applies to the single graph case.

self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

# Graph structure encoding
self.struct_encoder = nn.Linear(hidden_dim, hidden_dim)

high

The self.struct_encoder layer is initialized but never used in the forward pass. It seems it was intended to be used in _encode_graph_structure to create learnable structural embeddings. The current implementation uses a random tensor instead, which is incorrect. This layer should either be used or removed.

Comment on lines +44 to +46
self.global_gnn = nn.ModuleList(
    [GraphConv(hidden_dim, hidden_dim) for _ in range(num_layers)]
)

high

The self.global_gnn module list is initialized but is never used in the forward pass. The global processing currently only consists of a pooling operation. If a GNN was intended to run on the global graph, it should be applied within the _create_global_graph method or the main forward method. Otherwise, this unused module should be removed.

Comment on lines +70 to +78
if use_global_pool:
    if pool_type == "mean":
        self.global_pool = global_mean_pool
    elif pool_type == "max":
        self.global_pool = global_max_pool
    elif pool_type == "add":
        self.global_pool = global_add_pool
    else:
        self.global_pool = global_mean_pool

medium

The if/elif/else chain for selecting the pooling function is a bit verbose and can be simplified for better maintainability by using a dictionary to map pooling type names to their respective functions. This also makes the code more extensible if you decide to add more pooling types in the future.

        if use_global_pool:
            pool_functions = {
                "mean": global_mean_pool,
                "max": global_max_pool,
                "add": global_add_pool,
            }
            self.global_pool = pool_functions.get(pool_type, global_mean_pool)

Comment on lines +81 to +88
if activation == "relu":
    self.activation = F.relu
elif activation == "gelu":
    self.activation = F.gelu
elif activation == "leaky_relu":
    self.activation = F.leaky_relu
else:
    self.activation = F.relu

medium

This if/elif/else structure for selecting the activation function can be made more concise and maintainable by using a dictionary lookup. This pattern is generally preferred over long if/elif chains for mapping keys to values.

        # Activation function
        activation_functions = {
            "relu": F.relu,
            "gelu": F.gelu,
            "leaky_relu": F.leaky_relu,
        }
        self.activation = activation_functions.get(activation, F.relu)

Comment on lines +223 to +228
# Single graph case
attended_local, _ = self.cross_attention(
    local_seq, global_seq, global_seq
)
combined = torch.cat([attended_local.squeeze(0), x_local], dim=-1)
return self.final_proj(combined)

medium

This block of code for the single graph case is redundant. The logic is already handled inside the if batch is None: check at the beginning of the function (lines 187-190). This entire block can be removed and its logic should be merged into the if batch is None: block to avoid code duplication and improve clarity.

@NavidCOMSC
Owner Author

@Pabloo22 I have cleaned up and refactored the classes so that they pass the linting tests, apart from the main module imports, which are incomplete at the moment.
I can see Gemini raised some comments; I am waiting for you to have a look before I complete the remaining changes. Thanks!

@Pabloo22
Collaborator

Hi Navid,

The main issue is that all these models are for homogeneous graphs. The JSSP graph representation uses different types of edges and/or nodes.

@NavidCOMSC
Owner Author

Hi Pablo (@Pabloo22),

I have reviewed the graphs used for the JSSP in your dissertation. You indicated that both disjunctive and fully connected, type-aware graphs can be used to structure the JSSP samples. Have I now spent all these hours building these models for nothing? Do you mean that these graphs need to be modified to be applicable to the JSSP, or should they be discarded entirely?
What kind of graph should be modelled for the JSSP, then? Can you share some examples? Then I can learn what it takes to build the graph structure format for the JSSP. Your previous comment is not constructive at all. Explaining more would help me rethink how to configure a proper graph structure for training in the agent environment, and it would also mean my time didn't go to waste. Thanks for your understanding.

@Pabloo22
Collaborator

Pabloo22 commented Aug 14, 2025

I understand that it can be frustrating, but my comment is constructive: I'm telling you what you need to change. I thought you already knew about heterogeneous GNNs (we have discussed using HeteroData rather than Data, for example) and thus that a quick reminder would be sufficient. My dissertation explains what a heterogeneous/relational GNN is (both names are equivalent). In PyTorch, you can build one using HeteroConv. This is what my implementation in gnn_scheduler uses. If I were you, I'd focus on only one architecture (HGAT), but if you want to test more of them, that is also perfectly valid. I hope this helps. Let me know if you need more info.
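For readers unfamiliar with relational GNNs, here is a minimal plain-PyTorch sketch of the idea behind HeteroConv: one learnable transform per edge type, with messages summed per destination node type. The edge types (operation precedes operation, operation runs on machine) are hypothetical JSSP relations chosen for illustration; a real implementation would use torch_geometric's HeteroConv over a HeteroData object, as described above:

```python
import torch
import torch.nn as nn


class MiniHeteroLayer(nn.Module):
    """Sketch of one relational message-passing layer (illustrative only)."""

    def __init__(self, dim: int):
        super().__init__()
        # One linear transform per (hypothetical) edge type.
        self.rel = nn.ModuleDict({
            "op__precedes__op": nn.Linear(dim, dim, bias=False),
            "op__runs_on__machine": nn.Linear(dim, dim, bias=False),
        })

    def forward(self, x: dict, edges: dict) -> dict:
        # x: node features per node type, e.g. {"op": (N_op, d), "machine": (N_m, d)}
        # edges: {rel_name: (src_type, dst_type, edge_index of shape (2, E))}
        out = {k: torch.zeros_like(v) for k, v in x.items()}
        for rel_name, (src_type, dst_type, edge_index) in edges.items():
            src, dst = edge_index
            msg = self.rel[rel_name](x[src_type])[src]  # per-relation transform
            out[dst_type].index_add_(0, dst, msg)       # sum-aggregate at dst
        return out
```

The homogeneous models in this PR apply one shared transform to all nodes and edges; the dict-per-relation structure above is what makes the layer type-aware.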

@NavidCOMSC
Owner Author

I do recall the use of HeteroConv in the gnn_scheduler module. My shortcoming was that I presumed heterogeneous modelling was indicated by the use of GIN for a more constrained model application. Now, looking back at my notes and the repo, it is much clearer. I suppose I have to make mistakes to learn these intrinsic details.
I have been making the necessary changes to convert these classes into heterogeneous versions of the GNNs, though I am not sure all of the architectures can be adapted. I will post another note as soon as at least one is finished.

@NavidCOMSC NavidCOMSC closed this Sep 9, 2025
@NavidCOMSC NavidCOMSC reopened this Sep 9, 2025
@NavidCOMSC NavidCOMSC merged commit 481d4d0 into main Sep 9, 2025
4 of 6 checks passed
@NavidCOMSC NavidCOMSC deleted the gnn branch September 9, 2025 19:54