Conversation
Summary of Changes
Hello @NavidCOMSC, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly expands the project's machine learning capabilities by introducing a dedicated module for Graph Neural Network (GNN) architectures. It integrates four distinct GNN models: Graph Attention Network, Graph Transformer, Hierarchical GNN, and Multi-scale GNN. These additions are crucial for handling complex, graph-structured data, particularly in the context of job shop scheduling, by allowing for more sophisticated modeling of relationships between machines and operations. The new models are designed to be deployable within an existing RLModule class, laying the groundwork for advanced reinforcement learning applications on graph data.
Highlights
- **New GNN Models Module:** A new `gnn_models` module has been introduced to house various Graph Neural Network architectures, providing a structured approach for integrating different GNNs into the system.
- **Graph Attention Network (GAT) Implementation:** The pull request adds a `GraphAttentionNetwork` (GAT) implementation, which uses multi-head attention to process graph data, focusing on relevant operations and machines, and includes global pooling capabilities.
- **Graph Transformer Model:** A `GraphTransformer` model has been added, which applies transformer-style attention to graph nodes, incorporating positional and structural encodings to capture complex relationships within the graph.
- **Hierarchical Graph Neural Network:** A `HierarchicalGNN` is introduced, designed to process graph information at both local (operation-level) and global (job/machine-level) scales, combining these representations through cross-attention for a comprehensive understanding.
- **Multi-scale Graph Neural Network:** The `MultiScaleGNN` is added, enabling the processing of graphs at different levels of coarseness. This model dynamically coarsens and upsamples graph representations, applying different GNN layers at each scale and aggregating the results.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes

[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces several new GNN architectures. The code is well-structured with docstrings and type hints. However, I've found several critical issues across the new models that would likely prevent them from training correctly or cause them to behave unexpectedly. These include incorrect use of random tensors, unused layers, bugs in data handling for batches, and potential division-by-zero errors. Most critically, the GraphTransformer does not currently use the graph structure in its attention mechanism, and the MultiScaleGNN has bugs in its graph coarsening logic. I've left detailed comments on these issues. Furthermore, the pull request description mentions the lack of tests. Given the complexity of these models, adding a comprehensive test suite is crucial to ensure their correctness and prevent future regressions. I strongly recommend adding unit tests for each model, covering both single graph and batched graph scenarios, before merging.
```python
struct_encoding = torch.matmul(
    adj_encoding,
    torch.randn(num_nodes, hidden_dim, device=edge_index.device),
)
```
The structural encoding is created by multiplying the adjacency matrix with a newly generated random tensor (torch.randn(...)) on every forward pass. This is a critical issue as it introduces non-determinism and prevents the model from learning any meaningful representation from the graph structure. The tensor used for structural encoding should be a learnable parameter of the model.
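A minimal sketch of the suggested fix, registering the embedding table as a learnable parameter so gradients flow through the structural encoding. This is an illustration rather than the PR's actual code; the names `StructuralEncoding`, `max_nodes`, and `struct_embedding` are assumptions:

```python
import torch
import torch.nn as nn


class StructuralEncoding(nn.Module):
    """Learnable structural encoding: adjacency matrix times a learned embedding table."""

    def __init__(self, max_nodes: int, hidden_dim: int):
        super().__init__()
        # Learnable parameter replaces the per-forward torch.randn(...)
        self.struct_embedding = nn.Parameter(torch.randn(max_nodes, hidden_dim) * 0.02)

    def forward(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
        # Build a dense adjacency matrix from the edge list
        adj = torch.zeros(num_nodes, num_nodes, device=edge_index.device)
        adj[edge_index[0], edge_index[1]] = 1.0
        # Each node aggregates the learned embeddings of its neighbours
        return adj @ self.struct_embedding[:num_nodes]
```

Because the embedding is a parameter, the encoding is deterministic across forward passes and is updated by the optimizer along with the rest of the model.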
```python
def _create_attention_mask(
    self, edge_index: torch.Tensor, num_nodes: int, batch_size: int
) -> torch.Tensor | None:
    """Create attention mask to respect graph structure."""
    # For now, allow the fully connected network;
    # it can be modified to select specific edges or nodes
    return None
```
The _create_attention_mask method currently returns None, which causes the TransformerEncoder to use full attention over all nodes in the sequence. This completely ignores the graph's connectivity, meaning the model is not functioning as a Graph Transformer. To fix this, you need to create a proper attention mask from the edge_index that restricts attention to only neighboring nodes.
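One possible shape for such a mask, sketched under the assumption that the nodes are fed to `nn.TransformerEncoder` as a single sequence and the mask is passed as `src_mask` (where `True` means "do not attend"); the function name and batching details are illustrative:

```python
import torch


def create_attention_mask(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Boolean (num_nodes, num_nodes) mask for nn.TransformerEncoder.

    True marks positions that must NOT be attended to; attention is
    restricted to graph edges plus self-loops.
    """
    mask = torch.ones(num_nodes, num_nodes, dtype=torch.bool, device=edge_index.device)
    mask[edge_index[0], edge_index[1]] = False  # allow attention along edges
    mask.fill_diagonal_(False)                  # always allow self-attention
    return mask
```

For batched graphs the same idea applies per graph, with cross-graph positions left masked so that nodes from different graphs in the batch cannot attend to each other.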
```python
machine_mask = node_types == 1
job_mask = node_types == 2

x_encoded = torch.zeros_like(x[:, : self.hidden_dim])
```
The initialization of x_encoded using torch.zeros_like(x[:, : self.hidden_dim]) is incorrect and will likely lead to a runtime error. This code slices the input x up to self.hidden_dim, which will fail if x.shape[1] (which is node_dim) is smaller than hidden_dim. The tensor should be initialized with the correct target shape directly.
```diff
- x_encoded = torch.zeros_like(x[:, : self.hidden_dim])
+ x_encoded = torch.zeros(x.size(0), self.hidden_dim, device=x.device)
```
```python
num_nodes = x.size(0)

# Simple node clustering by grouping consecutive nodes
cluster_size = max(1, num_nodes // (num_nodes // coarsening_ratio))
```
The calculation of cluster_size can lead to a ZeroDivisionError. If num_nodes is less than coarsening_ratio, the expression num_nodes // coarsening_ratio evaluates to 0, causing a division by zero. You should add a check to handle this edge case, for example by not coarsening graphs that are too small.
```diff
- cluster_size = max(1, num_nodes // (num_nodes // coarsening_ratio))
+ num_clusters_approx = num_nodes // coarsening_ratio
+ if num_clusters_approx == 0:
+     # Graph is too small to coarsen at this scale. A single supernode will be created.
+     cluster_size = num_nodes
+ else:
+     cluster_size = max(1, num_nodes // num_clusters_approx)
```
```python
coarse_x.append(pooled_features)
coarse_node_idx += 1

coarse_batch = None
```
The line coarse_batch = None is incorrectly placed. It is outside the if/else block that handles batched vs. single graphs, causing coarse_batch to always be None after this block executes. This will break the logic for batched graphs in the _upsample_to_original method. This line should be moved inside the else block (e.g., after line 202) so it only applies to the single graph case.
```python
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

# Graph structure encoding
self.struct_encoder = nn.Linear(hidden_dim, hidden_dim)
```
The self.struct_encoder layer is initialized but never used in the forward pass. It seems it was intended to be used in _encode_graph_structure to create learnable structural embeddings. The current implementation uses a random tensor instead, which is incorrect. This layer should either be used or removed.
```python
self.global_gnn = nn.ModuleList(
    [GraphConv(hidden_dim, hidden_dim) for _ in range(num_layers)]
)
```
The self.global_gnn module list is initialized but is never used in the forward pass. The global processing currently only consists of a pooling operation. If a GNN was intended to run on the global graph, it should be applied within the _create_global_graph method or the main forward method. Otherwise, this unused module should be removed.
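If the layers are meant to be kept, one way to wire them in is a simple loop over the module list inside the global-processing step. This is a hedged sketch: the helper name `apply_global_gnn` and the assumption that the global graph is available as `(global_x, global_edge_index)` are illustrative, not taken from the PR:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_global_gnn(
    global_gnn: nn.ModuleList,
    global_x: torch.Tensor,
    global_edge_index: torch.Tensor,
) -> torch.Tensor:
    """Run each GNN layer over the global graph with a nonlinearity in between."""
    h = global_x
    for layer in global_gnn:
        # Each layer is assumed to take (node_features, edge_index), as GraphConv does
        h = F.relu(layer(h, global_edge_index))
    return h
```

The result would then replace the raw pooled features before they are combined with the local representation.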
```python
if use_global_pool:
    if pool_type == "mean":
        self.global_pool = global_mean_pool
    elif pool_type == "max":
        self.global_pool = global_max_pool
    elif pool_type == "add":
        self.global_pool = global_add_pool
    else:
        self.global_pool = global_mean_pool
```
The if/elif/else chain for selecting the pooling function is a bit verbose and can be simplified for better maintainability by using a dictionary to map pooling type names to their respective functions. This also makes the code more extensible if you decide to add more pooling types in the future.
```python
if use_global_pool:
    pool_functions = {
        "mean": global_mean_pool,
        "max": global_max_pool,
        "add": global_add_pool,
    }
    self.global_pool = pool_functions.get(pool_type, global_mean_pool)
```

```python
if activation == "relu":
    self.activation = F.relu
elif activation == "gelu":
    self.activation = F.gelu
elif activation == "leaky_relu":
    self.activation = F.leaky_relu
else:
    self.activation = F.relu
```
This if/elif/else structure for selecting the activation function can be made more concise and maintainable by using a dictionary lookup. This pattern is generally preferred over long if/elif chains for mapping keys to values.
```python
# Activation function
activation_functions = {
    "relu": F.relu,
    "gelu": F.gelu,
    "leaky_relu": F.leaky_relu,
}
self.activation = activation_functions.get(activation, F.relu)
```

```python
# Single graph case
attended_local, _ = self.cross_attention(
    local_seq, global_seq, global_seq
)
combined = torch.cat([attended_local.squeeze(0), x_local], dim=-1)
return self.final_proj(combined)
```
This block of code for the single graph case is redundant. The logic is already handled inside the if batch is None: check at the beginning of the function (lines 187-190). This entire block can be removed and its logic should be merged into the if batch is None: block to avoid code duplication and improve clarity.
@Pabloo22 I have cleaned up and refactored the classes so that they pass the linting tests, apart from the main module imports, as they are incomplete at the moment.
Hi Navid, the main issue is that all these models are for homogeneous graphs. The JSSP graph representation uses different types of edges and/or nodes.
Hi Pablo (@Pabloo22), I have reviewed the graphs used in the JSSP in your dissertation. You indicated that both disjunctive and fully-connected type-aware graphs can be used in structuring the JSSP samples. Have I spent so many hours building these models for no reason, then? Do you mean that these graphs need to be modified to be applicable to the JSSP, or should they be fully discarded?
I understand that it can be frustrating, but my comment is constructive. I'm telling you what you need to change. I thought you already knew about heterogeneous GNNs (we have discussed using
I do recall the use of
The GNN models module is created to host different GNN architectures that can be deployed for training in the RLModule class by defining the agent associated with a specific GNN model. The choice of architectures is based on their capabilities and how well they fit the attributes of machines and operations, with the edges between them serving as a message-passing interface.
There are currently no tests for this module, as it is going to be tested after the agent has passed training on a simple feed-forward (FF) neural network.