From f393a7b6c3f380dff86e8d59f5dc7bd239155bea Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Wed, 1 Feb 2023 22:57:11 -0800 Subject: [PATCH 1/5] Add example implementation of data-parallel MGN with random data --- applications/graph/MeshGraphNet/GNN.py | 114 +++++++++ .../graph/MeshGraphNet/GNNComponents.py | 238 ++++++++++++++++++ applications/graph/MeshGraphNet/README.md | 13 + .../graph/MeshGraphNet/SyntheticData.py | 29 +++ applications/graph/MeshGraphNet/Trainer.py | 88 +++++++ 5 files changed, 482 insertions(+) create mode 100644 applications/graph/MeshGraphNet/GNN.py create mode 100644 applications/graph/MeshGraphNet/GNNComponents.py create mode 100644 applications/graph/MeshGraphNet/README.md create mode 100644 applications/graph/MeshGraphNet/SyntheticData.py create mode 100644 applications/graph/MeshGraphNet/Trainer.py diff --git a/applications/graph/MeshGraphNet/GNN.py b/applications/graph/MeshGraphNet/GNN.py new file mode 100644 index 00000000000..0c518e65eed --- /dev/null +++ b/applications/graph/MeshGraphNet/GNN.py @@ -0,0 +1,114 @@ +import lbann +from .GNNComponents import MLP, GraphProcessor + + +def input_data_splitter(input_layer, + num_nodes, + num_edges, + in_dim_node, + in_dim_edge, + out_dim): + """ Takes a flattened sample from the Python DataReader and slices + them according to the graph attributes. + """ + + split_indices = [] + start_index = 0 + node_feature_size = num_nodes * in_dim_node + edge_feature_size = num_edges * in_dim_edge + out_feature_size = num_nodes * out_dim + + split_indices.append(start_index) + split_indices.append(split_indices[-1] + node_feature_size) + split_indices.append(split_indices[-1] + edge_feature_size) + split_indices.append(split_indices[-1] + num_edges) + split_indices.append(split_indices[-1] + num_edges) + split_indices.append(split_indices[-1] + out_feature_size) + + sliced_input = lbann.Slice(input_layer, axis=0, slice_points=split_indices) + + node_features = lbann.Reshape(lbann.Identity(sliced_input), + dims=[num_nodes, in_dim_node]) + edge_features = lbann.Reshape(lbann.Identity(sliced_input), + dims=[num_edges, in_dim_edge]) + source_node_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges]) + target_node_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges]) + + out_features = lbann.Reshape(lbann.Identity(sliced_input), + dims=[num_nodes, out_dim]) + + return node_features, edge_features, source_node_indices, \ + target_node_indices, out_features + + +def LBANN_GNN_Model(num_nodes, num_edges, + in_dim_node, in_dim_edge, out_dim, + out_dim_node=128, out_dim_edge=128, + hidden_dim_node=128, hidden_dim_edge=128, + hidden_layers_node=2, hidden_layers_edge=2, + mp_iterations=15, + hidden_dim_processor_node=128, hidden_dim_processor_edge=128, + hidden_layers_processor_node=2, hidden_layers_processor_edge=2, + norm_type=lbann.LayerNorm, + hidden_dim_decoder=128, hidden_layers_decoder=2, + num_epochs=10): + + # Set up model modules and associated weights + + node_encoder = MLP(in_dim=in_dim_node, out_dim=out_dim_node, + hidden_dim=hidden_dim_node, hidden_layers=hidden_layers_node, + norm_type=norm_type, name="graph_input_node_encoder") + + edge_encoder = MLP(in_dim=in_dim_edge, out_dim=out_dim_edge, + hidden_dim=hidden_dim_edge, hidden_layers=hidden_layers_edge, + norm_type=norm_type, name="graph_input_edge_encoder") + + graph_processor = GraphProcessor(num_nodes=num_nodes, + mp_iterations=mp_iterations, + in_dim_node=out_dim_node, in_dim_edge=out_dim_edge, + 
hidden_dim_node=hidden_dim_processor_node, + hidden_dim_edge=hidden_dim_processor_edge, + hidden_layers_node=hidden_layers_processor_node, + hidden_layers_edge=hidden_layers_processor_edge, + norm_type=norm_type) + + node_decoder = MLP(in_dim=out_dim_node, out_dim=out_dim, + hidden_dim=hidden_dim_decoder, hidden_layers=hidden_layers_decoder, + norm_type=None, name="graph_input_node_decoder") + + # Define LBANN Compute graph + + input_layer = lbann.Input(data_field='samples') + + node_features, edge_features, source_node_indices, target_node_indices,\ + out_features = input_data_splitter(input_layer, + num_nodes, + num_edges, + in_dim_node, + in_dim_edge, + out_dim) + + node_features = node_encoder(node_features) + edge_features = edge_encoder(edge_features) + + node_features, _ = graph_processor(node_features, edge_features, + source_node_indices, target_node_indices) + + calculated_features = node_decoder(node_features) + + loss = loss.MeanSquaredError(calculated_features, out_features) + + # Define some of the usual callbacks + + training_output = lbann.CallbackPrint(interval=1, + print_global_stat_only=False) + gpu_usage = lbann.CallbackGPUMemoryUsage() + timer = lbann.CallbackTimer() + callbacks = [training_output, gpu_usage, timer] + + # Putting it all together and compile the model + + layers = lbann.traverse_layer_graph(input_layer) + model = lbann.Model(num_epochs, layers=layers, objective_function=loss, + callbacks=callbacks) + return model diff --git a/applications/graph/MeshGraphNet/GNNComponents.py b/applications/graph/MeshGraphNet/GNNComponents.py new file mode 100644 index 00000000000..b50f62e6b3f --- /dev/null +++ b/applications/graph/MeshGraphNet/GNNComponents.py @@ -0,0 +1,238 @@ +import lbann +from lbann.modules import Module, ChannelwiseFullyConnectedModule + + +class MLP(Module): + """ + Applies channelwise MLP with ReLU activation with Layer Normalization + with a specified number of hidden layers + """ + global_count = 0 + + def __init__(self, + in_dim, + out_dim, + hidden_dim, + hidden_layers, + norm_type=lbann.LayerNorm, + name=None): + + super().__init__() + MLP.global_count += 1 + + self.instance = 0 + self.in_dim = in_dim + self.out_dim = out_dim + self.hidden_dim = hidden_dim + self.hidden_layers = hidden_layers + + self.name = (name if name + else f'MLP_{MLP.global_count}') + + self.layers = [ChannelwiseFullyConnectedModule(hidden_dim, + bias=True, + activation=lbann.Relu)] + for i in range(hidden_layers): + # Total number of MLPs is hidden layers + 2 (input and output) + self.layers.append(ChannelwiseFullyConnectedModule(hidden_dim, + bias=True, + activation=lbann.Relu)) + + self.layers.append(ChannelwiseFullyConnectedModule(out_dim, + bias=True, + activation=None)) + + self.norm_type = None + + if norm_type: + if isinstance(norm_type, type): + self.norm_type = norm_type + else: + self.norm_type = type(norm_type) + + if not issubclass(norm_type, lbann.Layer): + raise ValueError("Normalization must be a layer") + + def forward(self, x): + """ + Args: + x (Layer) : Expected shape (Batch, N, self.in_dim) + + Returns: + (Layer): Expected shape (Batch, N, self.out_dim) + """ + self.instance += 1 + name = f"{self.name}_instance_{self.instance}" + + for layer in self.layers: + x = layer(x) + + if self.norm_type: + return self.norm_type(x) + return x + + +class EdgeProcessor(Module): + """ Applies MLP transform on concatenated node and edge features + """ + global_count = 0 + + + def __init__(self, + in_dim_node=128, + in_dim_edge=128, + hidden_dim=128, + 
hidden_layers=2, + norm_type=lbann.LayerNorm, + name=None): + super().__init__() + self.instance = 0 + self.name = (name if name + else f'EdgeProcessor_{EdgeProcessor.global_count}') + + self.edge_mlp = MLP(2 * in_dim_node + in_dim_edge, + in_dim_edge, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + norm_type=norm_type, + name=f"{self.name}_edge_mlp") + def forward(self, + node_features, + edge_features, + source_node_indices, + target_node_indices,): + """ + Args: + node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node) + edge_features (Layer) : Expected shape (Batch, num_edges, self.in_dim_edge) + source_node_indices (Layer) : Expected shape (Batch, num_edges) + target_node_indices (Layer) : Expected shape (Batch, num_edges) + + Returns: + (Layer): Expected shape (Batch, Num_edges, self.in_dim_edge) + """ + self.instance += 1 + source_node_features = lbann.Gather(node_features, source_node_indices, axis=0) + target_node_features = lbann.Gather(node_features, target_node_indices, axis=0) + + x = lbann.Concatenation([source_node_features, target_node_features, edge_features], + axis=1, + name=f"{self.name}_{self.instance}_concat_features") + x = self.edge_mlp(x) + + return lbann.Sum(edge_features, x, + name=f"{self.name}_{self.instance}_residual_sum") + + +class NodeProcessor(Module): + """ Applies MLP transform on scatter-summed edge features and node features + """ + global_count = 0 + + + def __init__(self, + num_nodes, + in_dim_node=128, + in_dim_edge=128, + hidden_dim=128, + hidden_layers=2, + norm_type=lbann.LayerNorm, + name=None): + super().__init__() + self.instance = 0 + self.name = (name if name + else f'NodeProcessor_{NodeProcessor.global_count}') + self.num_nodes = num_nodes + self.in_dim_edge = in_dim_edge + self.node_mlp = MLP(in_dim_node + in_dim_edge, + in_dim_node, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + norm_type=norm_type, + name=f"{self.name}_node_mlp") + + def forward(self, + node_features, + edge_features, + target_edge_indices): + """ + Args: + node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node) + edge_features (Layer) : Expected shape (Batch, Num_edges, self.in_dim_edge) + edge_indices (Layer): Expected shape (Batch, Num_edges) + Returns: + (Layer): Expected shape (Batch, Num_nodes, self.in_dim_node) + """ + self.instance += 1 + + edge_feature_sum = lbann.Scatter(edge_features, target_edge_indices, + name="f{self.name}_{self.instance}_scatter", + dims=[self.num_nodes, self.in_dim_edge], + axis=0) + + x = lbann.Concatenation([node_features, edge_feature_sum], + axis=1, + name=f"{self.name}_{self.instance}_concat_features") + x = self.node_mlp(x) + + return lbann.Sum(edge_features, x, + name=f"{self.name}_{self.instance}_residual_sum") + + +class GraphProcessor(Module): + """ Graph processor module + """ + + def __init__(self, + num_nodes, + mp_iterations=15, + in_dim_node=128, in_dim_edge=128, + hidden_dim_node=128, hidden_dim_edge=128, + hidden_layers_node=2, hidden_layers_edge=2, + norm_type=lbann.LayerNorm): + super().__init__() + + self.blocks = [] + + for _ in range(mp_iterations): + node_processor = NodeProcessor(num_nodes=num_nodes, + in_dim_node=in_dim_node, + in_dim_edge=in_dim_edge, + hidden_dim=hidden_dim_node, + hidden_layers=hidden_layers_node, + norm_type=norm_type) + + edge_processor = EdgeProcessor(in_dim_node=in_dim_node, + in_dim_edge=in_dim_edge, + hidden_dim=hidden_dim_edge, + hidden_layers=hidden_layers_edge, + norm_type=norm_type) + + self.blocks.append((node_processor, 
edge_processor)) + + def forward(self, + node_features, + edge_features, + source_node_indices, + target_node_indices): + """ + Args: + node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node) + edge_features (Layer) : Expected shape (Batch, Num_edges, self.in_dim_edge) + source_node_indices (Layer) : Expected shape (Batch, num_edges) + target_node_indices (Layer) : Expected shape (Batch, num_edges) + Returns: + (Layer, Layer): Expected shape (Batch, Num_nodes, self.in_dim_node) and + (Batch, num_edges, self.in_dim_edge) + """ + + for node_processor, edge_processor in self.blocks: + x = node_processor(node_features, edge_features, target_node_indices) + e = edge_processor(node_features, edge_features, + source_node_indices, + target_node_indices) + + node_features = x + edge_features = e + + return node_features, edge_features diff --git a/applications/graph/MeshGraphNet/README.md b/applications/graph/MeshGraphNet/README.md new file mode 100644 index 00000000000..9b83b535cbf --- /dev/null +++ b/applications/graph/MeshGraphNet/README.md @@ -0,0 +1,13 @@ +## Mesh Graph Networks + +This example contains LBANN implementation of mesh-based graph neural network with +syntheticly generated data. + +--- +### Running the example + +The data-parallel model can be run with the synthetic data with: + +```bash +python Trainer.py --mini-batch-size --num-epochs +``` diff --git a/applications/graph/MeshGraphNet/SyntheticData.py b/applications/graph/MeshGraphNet/SyntheticData.py new file mode 100644 index 00000000000..2538c210fec --- /dev/null +++ b/applications/graph/MeshGraphNet/SyntheticData.py @@ -0,0 +1,29 @@ +import numpy as np + + +NUM_SAMPLES = 10000 +NUM_NODES = 100 +NUM_EDGES = 1000 +NODE_FEATS = 5 +EDGE_FEATS = 3 +OUT_FEATS = 3 + +NODE_FEATURE_SIZE = NUM_NODES * NODE_FEATS +EDGE_FEATURE_SIZE = NUM_EDGES * EDGE_FEATS +OUT_FEATURE_SIZE = NUM_EDGES * OUT_FEATS + +def get_sample_func(index): + random_features = np.random.random(NODE_FEATURE_SIZE+OUT_FEATURE_SIZE).astype(np.float32) + source_indices = np.random.randint(-1, NUM_NODES, size=NUM_EDGES).astype(np.float32) + target_indices = np.random.randint(-1, NUM_NODES, size=NUM_EDGES).astype(np.float32) + out_features = np.random.random(EDGE_FEATURE_SIZE).astype(np.float32) + + return np.concatenate([random_features, source_indices, target_indices, out_features]) + +def num_samples_func(): + return NUM_SAMPLES + +def sample_dims_func(): + + size = NODE_FEATURE_SIZE + EDGE_FEATURE_SIZE + OUT_FEATURE_SIZE + 2 * NUM_EDGES + return (size, ) diff --git a/applications/graph/MeshGraphNet/Trainer.py b/applications/graph/MeshGraphNet/Trainer.py new file mode 100644 index 00000000000..6c657172b74 --- /dev/null +++ b/applications/graph/MeshGraphNet/Trainer.py @@ -0,0 +1,88 @@ +import lbann +import lbann.contrib.launcher +import lbann.contrib.args +import argparse + +import os.path as osp +from .GNN import LBANN_GNN_Model + +data_dir = osp.dirname(osp.realpath(__file__)) + + +desc = ("Training a Mesh Graph Neural Network Model Using LBANN") + +parser = argparse.ArgumentParser(description=desc) + +lbann.contrib.args.add_scheduler_arguments(parser) +lbann.contrib.args.add_optimizer_arguments(parser) + +parser.add_argument( + '--num-epochs', action='store', default=3, type=int, + help='number of epochs (deafult: 3)', metavar='NUM') + +parser.add_argument( + '--mini-batch-size', action='store', default=256, type=int, + help="mini-batch size (default: 256)", metavar='NUM') + +parser.add_argument( + '--job-name', action='store', default="MGN", 
type=str, + help="Job name for scheduler", metavar='NAME') + +args = parser.parse_args() +kwargs = lbann.contrib.args.get_scheduler_kwargs(args) + +# Some training parameters + +MINI_BATCH_SIZE = args.mini_batch_size +NUM_EPOCHS = args.num_epochs +JOB_NAME = args.job_name + +# Some synthetic attributes to get the model running + +NUM_NODES = 100 +NUM_EDGES = 1000 +NODE_FEATS = 5 +EDGE_FEATS = 3 +OUT_FEATS = 3 + +def make_data_reader(classname, + sample='get_sample_func', + num_samples='num_samples_func', + sample_dims='sample_dims_func', + validation_percent=0.1): + reader = lbann.reader_pb2.DataReader() + _reader = reader.reader.add() + _reader.name = 'python' + _reader.role = 'train' + _reader.shuffle = True + _reader.percent_of_data_to_use = 1.0 + _reader.validation_percent = validation_percent + _reader.python.module = classname + _reader.python.module_dir = data_dir + _reader.python.sample_function = sample + _reader.python.num_samples_function = num_samples + _reader.python.sample_dims_function = sample_dims + return reader + +def main(): + # Use the defaults for the other parameters + model = LBANN_GNN_Model(num_nodes=NUM_NODES, + num_edges=NUM_EDGES, + in_dim_node=NODE_FEATS, + in_dim_edge=EDGE_FEATS, + out_dim=OUT_FEATS, + num_epochs=NUM_EPOCHS) + + optimizer = lbann.SGD(learn_rate=1e-4) + data_reader = make_data_reader("SyntheticData") + trainer = lbann.Trainer(mini_batch_size=MINI_BATCH_SIZE) + + lbann.contrib.launcher.run(trainer, + model, + data_reader, + optimizer, + job_name=JOB_NAME, + **kwargs) + +if __name__ == '__main__': + main() \ No newline at end of file From 5ff73fba5c74d0d15329130959270d161a7ae2f5 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Tue, 7 Mar 2023 10:47:05 -0800 Subject: [PATCH 2/5] Updates to the MGN synthetic data reader --- applications/graph/MeshGraphNet/GNN.py | 4 ++-- .../graph/MeshGraphNet/GNNComponents.py | 6 +++-- .../graph/MeshGraphNet/SyntheticData.py | 21 +++++++++++++----- applications/graph/MeshGraphNet/Trainer.py | 22 ++++++++++--------- 4 files changed, 33 insertions(+), 20 deletions(-) diff --git a/applications/graph/MeshGraphNet/GNN.py b/applications/graph/MeshGraphNet/GNN.py index 0c518e65eed..ad5235ff0ea 100644 --- a/applications/graph/MeshGraphNet/GNN.py +++ b/applications/graph/MeshGraphNet/GNN.py @@ -1,5 +1,5 @@ import lbann -from .GNNComponents import MLP, GraphProcessor +from GNNComponents import MLP, GraphProcessor def input_data_splitter(input_layer, @@ -96,7 +96,7 @@ def LBANN_GNN_Model(num_nodes, num_edges, calculated_features = node_decoder(node_features) - loss = loss.MeanSquaredError(calculated_features, out_features) + loss = lbann.MeanSquaredError(calculated_features, out_features) # Define some of the usual callbacks diff --git a/applications/graph/MeshGraphNet/GNNComponents.py b/applications/graph/MeshGraphNet/GNNComponents.py index b50f62e6b3f..ced2f6082db 100644 --- a/applications/graph/MeshGraphNet/GNNComponents.py +++ b/applications/graph/MeshGraphNet/GNNComponents.py @@ -86,6 +86,7 @@ def __init__(self, norm_type=lbann.LayerNorm, name=None): super().__init__() + EdgeProcessor.global_count += 1 self.instance = 0 self.name = (name if name else f'EdgeProcessor_{EdgeProcessor.global_count}') @@ -139,6 +140,7 @@ def __init__(self, norm_type=lbann.LayerNorm, name=None): super().__init__() + NodeProcessor.global_count += 1 self.instance = 0 self.name = (name if name else f'NodeProcessor_{NodeProcessor.global_count}') @@ -166,7 +168,7 @@ def forward(self, self.instance += 1 edge_feature_sum = 
lbann.Scatter(edge_features, target_edge_indices, - name="f{self.name}_{self.instance}_scatter", + name=f"{self.name}_{self.instance}_scatter", dims=[self.num_nodes, self.in_dim_edge], axis=0) @@ -175,7 +177,7 @@ def forward(self, name=f"{self.name}_{self.instance}_concat_features") x = self.node_mlp(x) - return lbann.Sum(edge_features, x, + return lbann.Sum(node_features, x, name=f"{self.name}_{self.instance}_residual_sum") diff --git a/applications/graph/MeshGraphNet/SyntheticData.py b/applications/graph/MeshGraphNet/SyntheticData.py index 2538c210fec..9f325381825 100644 --- a/applications/graph/MeshGraphNet/SyntheticData.py +++ b/applications/graph/MeshGraphNet/SyntheticData.py @@ -1,12 +1,17 @@ import numpy as np +import configparser + + +DATA_CONFIG = configparser.ConfigParser() +DATA_CONFIG.read("data_config.ini") +NUM_NODES = 100 # int(DATA_CONFIG['DEFAULT']['NUM_NODES']) +NUM_EDGES = 10000 # int(DATA_CONFIG['DEFAULT']['NUM_EDGES']) +NODE_FEATS = 5 # int(DATA_CONFIG['DEFAULT']['NODE_FEATURES']) +EDGE_FEATS = 3 # int(DATA_CONFIG['DEFAULT']['EDGE_FEATURES']) +OUT_FEATS = 3 # int(DATA_CONFIG['DEFAULT']['OUT_FEATURES']) +NUM_SAMPLES = 100 -NUM_SAMPLES = 10000 -NUM_NODES = 100 -NUM_EDGES = 1000 -NODE_FEATS = 5 -EDGE_FEATS = 3 -OUT_FEATS = 3 NODE_FEATURE_SIZE = NUM_NODES * NODE_FEATS EDGE_FEATURE_SIZE = NUM_EDGES * EDGE_FEATS @@ -27,3 +32,7 @@ def sample_dims_func(): size = NODE_FEATURE_SIZE + EDGE_FEATURE_SIZE + OUT_FEATURE_SIZE + 2 * NUM_EDGES return (size, ) + + +if __name__ == '__main__': + print(NUM_NODES) diff --git a/applications/graph/MeshGraphNet/Trainer.py b/applications/graph/MeshGraphNet/Trainer.py index 6c657172b74..4b829d12137 100644 --- a/applications/graph/MeshGraphNet/Trainer.py +++ b/applications/graph/MeshGraphNet/Trainer.py @@ -2,9 +2,9 @@ import lbann.contrib.launcher import lbann.contrib.args import argparse - +import configparser import os.path as osp -from .GNN import LBANN_GNN_Model +from GNN import LBANN_GNN_Model data_dir = osp.dirname(osp.realpath(__file__)) @@ -21,8 +21,8 @@ help='number of epochs (deafult: 3)', metavar='NUM') parser.add_argument( - '--mini-batch-size', action='store', default=256, type=int, - help="mini-batch size (default: 256)", metavar='NUM') + '--mini-batch-size', action='store', default=4, type=int, + help="mini-batch size (default: 4)", metavar='NUM') parser.add_argument( '--job-name', action='store', default="MGN", type=str, @@ -38,12 +38,14 @@ JOB_NAME = args.job_name # Some synthetic attributes to get the model running +DATA_CONFIG = configparser.ConfigParser() +DATA_CONFIG.read("data_config.ini") -NUM_NODES = 100 -NUM_EDGES = 1000 -NODE_FEATS = 5 -EDGE_FEATS = 3 -OUT_FEATS = 3 +NUM_NODES = int(DATA_CONFIG['DEFAULT']['NUM_NODES']) +NUM_EDGES = int(DATA_CONFIG['DEFAULT']['NUM_EDGES']) +NODE_FEATS = int(DATA_CONFIG['DEFAULT']['NODE_FEATURES']) +EDGE_FEATS = int(DATA_CONFIG['DEFAULT']['EDGE_FEATURES']) +OUT_FEATS = int(DATA_CONFIG['DEFAULT']['OUT_FEATURES']) def make_data_reader(classname, sample='get_sample_func', @@ -85,4 +87,4 @@ def main(): **kwargs) if __name__ == '__main__': - main() \ No newline at end of file + main() From 4df7c8885570aaa1e1956f7314c1c2c267dfc870 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Sun, 26 Mar 2023 16:14:34 -0400 Subject: [PATCH 3/5] Apply suggestions from code review Co-authored-by: Tal Ben-Nun --- applications/graph/MeshGraphNet/README.md | 5 +++-- applications/graph/MeshGraphNet/SyntheticData.py | 2 -- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git 
a/applications/graph/MeshGraphNet/README.md b/applications/graph/MeshGraphNet/README.md
index 9b83b535cbf..4c9e5ac3108 100644
--- a/applications/graph/MeshGraphNet/README.md
+++ b/applications/graph/MeshGraphNet/README.md
@@ -1,7 +1,8 @@
 ## Mesh Graph Networks
 
-This example contains LBANN implementation of mesh-based graph neural network with
-syntheticly generated data.
+This example contains an LBANN implementation of a mesh-based graph neural network (MeshGraphNet) with
+synthetically generated data.
+For more information about the model, refer to: T. Pfaff et al., "Learning Mesh-Based Simulation with Graph Networks". ICLR'21.
 
 ---
 ### Running the example
diff --git a/applications/graph/MeshGraphNet/SyntheticData.py b/applications/graph/MeshGraphNet/SyntheticData.py
index 9f325381825..47283f0eb13 100644
--- a/applications/graph/MeshGraphNet/SyntheticData.py
+++ b/applications/graph/MeshGraphNet/SyntheticData.py
@@ -34,5 +34,3 @@ def sample_dims_func():
     return (size, )
 
 
-if __name__ == '__main__':
-    print(NUM_NODES)

From bb106aff13ce0c31f4910129336cbc72402340d3 Mon Sep 17 00:00:00 2001
From: Shehtab Zaman
Date: Mon, 4 Dec 2023 10:01:35 -0800
Subject: [PATCH 4/5] Adding suggestions from code review

Add synthetic config file
---
 applications/graph/MeshGraphNet/GNN.py       | 262 ++++++----
 .../graph/MeshGraphNet/GNNComponents.py      | 470 +++++++++---------
 applications/graph/MeshGraphNet/README.md    |   5 +
 .../graph/MeshGraphNet/data_config.ini       |   6 +
 4 files changed, 410 insertions(+), 333 deletions(-)
 create mode 100644 applications/graph/MeshGraphNet/data_config.ini

diff --git a/applications/graph/MeshGraphNet/GNN.py b/applications/graph/MeshGraphNet/GNN.py
index ad5235ff0ea..0676f843c7b 100644
--- a/applications/graph/MeshGraphNet/GNN.py
+++ b/applications/graph/MeshGraphNet/GNN.py
@@ -2,113 +2,155 @@
 from GNNComponents import MLP, GraphProcessor
 
 
-def input_data_splitter(input_layer,
-                        num_nodes,
-                        num_edges,
-                        in_dim_node,
-                        in_dim_edge,
-                        out_dim):
-    """ Takes a flattened sample from the Python DataReader and slices
-    them according to the graph attributes.
- """ - - split_indices = [] - start_index = 0 - node_feature_size = num_nodes * in_dim_node - edge_feature_size = num_edges * in_dim_edge - out_feature_size = num_nodes * out_dim - - split_indices.append(start_index) - split_indices.append(split_indices[-1] + node_feature_size) - split_indices.append(split_indices[-1] + edge_feature_size) - split_indices.append(split_indices[-1] + num_edges) - split_indices.append(split_indices[-1] + num_edges) - split_indices.append(split_indices[-1] + out_feature_size) - - sliced_input = lbann.Slice(input_layer, axis=0, slice_points=split_indices) - - node_features = lbann.Reshape(lbann.Identity(sliced_input), - dims=[num_nodes, in_dim_node]) - edge_features = lbann.Reshape(lbann.Identity(sliced_input), - dims=[num_edges, in_dim_edge]) - source_node_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges]) - target_node_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges]) - - out_features = lbann.Reshape(lbann.Identity(sliced_input), - dims=[num_nodes, out_dim]) - - return node_features, edge_features, source_node_indices, \ - target_node_indices, out_features - - -def LBANN_GNN_Model(num_nodes, num_edges, - in_dim_node, in_dim_edge, out_dim, - out_dim_node=128, out_dim_edge=128, - hidden_dim_node=128, hidden_dim_edge=128, - hidden_layers_node=2, hidden_layers_edge=2, - mp_iterations=15, - hidden_dim_processor_node=128, hidden_dim_processor_edge=128, - hidden_layers_processor_node=2, hidden_layers_processor_edge=2, - norm_type=lbann.LayerNorm, - hidden_dim_decoder=128, hidden_layers_decoder=2, - num_epochs=10): - - # Set up model modules and associated weights - - node_encoder = MLP(in_dim=in_dim_node, out_dim=out_dim_node, - hidden_dim=hidden_dim_node, hidden_layers=hidden_layers_node, - norm_type=norm_type, name="graph_input_node_encoder") - - edge_encoder = MLP(in_dim=in_dim_edge, out_dim=out_dim_edge, - hidden_dim=hidden_dim_edge, hidden_layers=hidden_layers_edge, - norm_type=norm_type, name="graph_input_edge_encoder") - - graph_processor = GraphProcessor(num_nodes=num_nodes, - mp_iterations=mp_iterations, - in_dim_node=out_dim_node, in_dim_edge=out_dim_edge, - hidden_dim_node=hidden_dim_processor_node, - hidden_dim_edge=hidden_dim_processor_edge, - hidden_layers_node=hidden_layers_processor_node, - hidden_layers_edge=hidden_layers_processor_edge, - norm_type=norm_type) - - node_decoder = MLP(in_dim=out_dim_node, out_dim=out_dim, - hidden_dim=hidden_dim_decoder, hidden_layers=hidden_layers_decoder, - norm_type=None, name="graph_input_node_decoder") - - # Define LBANN Compute graph - - input_layer = lbann.Input(data_field='samples') - - node_features, edge_features, source_node_indices, target_node_indices,\ - out_features = input_data_splitter(input_layer, - num_nodes, - num_edges, - in_dim_node, - in_dim_edge, - out_dim) - - node_features = node_encoder(node_features) - edge_features = edge_encoder(edge_features) - - node_features, _ = graph_processor(node_features, edge_features, - source_node_indices, target_node_indices) - - calculated_features = node_decoder(node_features) - - loss = lbann.MeanSquaredError(calculated_features, out_features) - - # Define some of the usual callbacks - - training_output = lbann.CallbackPrint(interval=1, - print_global_stat_only=False) - gpu_usage = lbann.CallbackGPUMemoryUsage() - timer = lbann.CallbackTimer() - callbacks = [training_output, gpu_usage, timer] - - # Putting it all together and compile the model - - layers = lbann.traverse_layer_graph(input_layer) - model = 
lbann.Model(num_epochs, layers=layers, objective_function=loss, - callbacks=callbacks) - return model +def input_data_splitter( + input_layer, num_nodes, num_edges, in_dim_node, in_dim_edge, out_dim +): + """Takes a flattened sample from the Python DataReader and slices + them according to the graph attributes. + """ + + split_indices = [] + start_index = 0 + node_feature_size = num_nodes * in_dim_node + edge_feature_size = num_edges * in_dim_edge + out_feature_size = num_nodes * out_dim + + split_indices.append(start_index) + split_indices.append(split_indices[-1] + node_feature_size) + split_indices.append(split_indices[-1] + edge_feature_size) + split_indices.append(split_indices[-1] + num_edges) + split_indices.append(split_indices[-1] + num_edges) + split_indices.append(split_indices[-1] + out_feature_size) + + sliced_input = lbann.Slice(input_layer, axis=0, slice_points=split_indices) + + node_features = lbann.Reshape( + lbann.Identity(sliced_input), dims=[num_nodes, in_dim_node] + ) + edge_features = lbann.Reshape( + lbann.Identity(sliced_input), dims=[num_edges, in_dim_edge] + ) + source_node_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges]) + target_node_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges]) + + out_features = lbann.Reshape( + lbann.Identity(sliced_input), dims=[num_nodes, out_dim] + ) + + return ( + node_features, + edge_features, + source_node_indices, + target_node_indices, + out_features, + ) + + +def LBANN_GNN_Model( + num_nodes, + num_edges, + in_dim_node, + in_dim_edge, + out_dim, + out_dim_node=128, + out_dim_edge=128, + hidden_dim_node=128, + hidden_dim_edge=128, + hidden_layers_node=2, + hidden_layers_edge=2, + mp_iterations=15, + hidden_dim_processor_node=128, + hidden_dim_processor_edge=128, + hidden_layers_processor_node=2, + hidden_layers_processor_edge=2, + norm_type=lbann.LayerNorm, + hidden_dim_decoder=128, + hidden_layers_decoder=2, + num_epochs=10, +): + # Set up model modules and associated weights + + node_encoder = MLP( + in_dim=in_dim_node, + out_dim=out_dim_node, + hidden_dim=hidden_dim_node, + hidden_layers=hidden_layers_node, + norm_type=norm_type, + name="graph_input_node_encoder", + ) + + edge_encoder = MLP( + in_dim=in_dim_edge, + out_dim=out_dim_edge, + hidden_dim=hidden_dim_edge, + hidden_layers=hidden_layers_edge, + norm_type=norm_type, + name="graph_input_edge_encoder", + ) + + # The graph processor currently only implements homogenous node graphs + # so we do not distinguish between world and mesh nodes. 
LBANN supports + # heterogenous and multi-graphs in general + + # We also disable adaptive remeshing as that may require recomputing + # the compute graph due to changing graph characteristics + graph_processor = GraphProcessor( + num_nodes=num_nodes, + mp_iterations=mp_iterations, + in_dim_node=out_dim_node, + in_dim_edge=out_dim_edge, + hidden_dim_node=hidden_dim_processor_node, + hidden_dim_edge=hidden_dim_processor_edge, + hidden_layers_node=hidden_layers_processor_node, + hidden_layers_edge=hidden_layers_processor_edge, + norm_type=norm_type, + ) + + node_decoder = MLP( + in_dim=out_dim_node, + out_dim=out_dim, + hidden_dim=hidden_dim_decoder, + hidden_layers=hidden_layers_decoder, + norm_type=None, + name="graph_input_node_decoder", + ) + + # Define LBANN Compute graph + + input_layer = lbann.Input(data_field="samples") + + ( + node_features, + edge_features, + source_node_indices, + target_node_indices, + out_features, + ) = input_data_splitter( + input_layer, num_nodes, num_edges, in_dim_node, in_dim_edge, out_dim + ) + + node_features = node_encoder(node_features) + edge_features = edge_encoder(edge_features) + + node_features, _ = graph_processor( + node_features, edge_features, source_node_indices, target_node_indices + ) + + calculated_features = node_decoder(node_features) + + loss = lbann.MeanSquaredError(calculated_features, out_features) + + # Define some of the usual callbacks + + training_output = lbann.CallbackPrint(interval=1, print_global_stat_only=False) + gpu_usage = lbann.CallbackGPUMemoryUsage() + timer = lbann.CallbackTimer() + callbacks = [training_output, gpu_usage, timer] + + # Putting it all together and compile the model + + layers = lbann.traverse_layer_graph(input_layer) + model = lbann.Model( + num_epochs, layers=layers, objective_function=loss, callbacks=callbacks + ) + return model diff --git a/applications/graph/MeshGraphNet/GNNComponents.py b/applications/graph/MeshGraphNet/GNNComponents.py index ced2f6082db..24c4bede170 100644 --- a/applications/graph/MeshGraphNet/GNNComponents.py +++ b/applications/graph/MeshGraphNet/GNNComponents.py @@ -3,238 +3,262 @@ class MLP(Module): - """ - Applies channelwise MLP with ReLU activation with Layer Normalization - with a specified number of hidden layers - """ - global_count = 0 - - def __init__(self, - in_dim, - out_dim, - hidden_dim, - hidden_layers, - norm_type=lbann.LayerNorm, - name=None): - - super().__init__() - MLP.global_count += 1 - - self.instance = 0 - self.in_dim = in_dim - self.out_dim = out_dim - self.hidden_dim = hidden_dim - self.hidden_layers = hidden_layers - - self.name = (name if name - else f'MLP_{MLP.global_count}') - - self.layers = [ChannelwiseFullyConnectedModule(hidden_dim, - bias=True, - activation=lbann.Relu)] - for i in range(hidden_layers): - # Total number of MLPs is hidden layers + 2 (input and output) - self.layers.append(ChannelwiseFullyConnectedModule(hidden_dim, - bias=True, - activation=lbann.Relu)) - - self.layers.append(ChannelwiseFullyConnectedModule(out_dim, - bias=True, - activation=None)) - - self.norm_type = None - - if norm_type: - if isinstance(norm_type, type): - self.norm_type = norm_type - else: - self.norm_type = type(norm_type) - - if not issubclass(norm_type, lbann.Layer): - raise ValueError("Normalization must be a layer") - - def forward(self, x): """ - Args: - x (Layer) : Expected shape (Batch, N, self.in_dim) - - Returns: - (Layer): Expected shape (Batch, N, self.out_dim) + Applies channelwise MLP with ReLU activation with Layer Normalization + with a 
specified number of hidden layers """ - self.instance += 1 - name = f"{self.name}_instance_{self.instance}" - for layer in self.layers: - x = layer(x) - - if self.norm_type: - return self.norm_type(x) - return x + global_count = 0 + + def __init__( + self, + in_dim, + out_dim, + hidden_dim, + hidden_layers, + norm_type=lbann.LayerNorm, + name=None, + ): + super().__init__() + MLP.global_count += 1 + + self.instance = 0 + self.in_dim = in_dim + self.out_dim = out_dim + self.hidden_dim = hidden_dim + self.hidden_layers = hidden_layers + + self.name = name if name else f"MLP_{MLP.global_count}" + + self.layers = [ + ChannelwiseFullyConnectedModule( + hidden_dim, bias=True, activation=lbann.Relu + ) + ] + for i in range(hidden_layers): + # Total number of MLPs is hidden layers + 2 (input and output) + self.layers.append( + ChannelwiseFullyConnectedModule( + hidden_dim, bias=True, activation=lbann.Relu + ) + ) + + self.layers.append( + ChannelwiseFullyConnectedModule(out_dim, bias=True, activation=None) + ) + + self.norm_type = None + + if norm_type: + if isinstance(norm_type, type): + self.norm_type = norm_type + else: + self.norm_type = type(norm_type) + + if not issubclass(norm_type, lbann.Layer): + raise ValueError("Normalization must be a layer") + + def forward(self, x): + """ + Args: + x (Layer) : Expected shape (Batch, N, self.in_dim) + + Returns: + (Layer): Expected shape (Batch, N, self.out_dim) + """ + self.instance += 1 + name = f"{self.name}_instance_{self.instance}" + + for layer in self.layers: + x = layer(x) + + if self.norm_type: + return self.norm_type(x, name=name+"_norm") + return x class EdgeProcessor(Module): - """ Applies MLP transform on concatenated node and edge features - """ - global_count = 0 - - - def __init__(self, - in_dim_node=128, - in_dim_edge=128, - hidden_dim=128, - hidden_layers=2, - norm_type=lbann.LayerNorm, - name=None): - super().__init__() - EdgeProcessor.global_count += 1 - self.instance = 0 - self.name = (name if name - else f'EdgeProcessor_{EdgeProcessor.global_count}') - - self.edge_mlp = MLP(2 * in_dim_node + in_dim_edge, - in_dim_edge, - hidden_dim=hidden_dim, - hidden_layers=hidden_layers, - norm_type=norm_type, - name=f"{self.name}_edge_mlp") - def forward(self, - node_features, - edge_features, - source_node_indices, - target_node_indices,): - """ - Args: - node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node) - edge_features (Layer) : Expected shape (Batch, num_edges, self.in_dim_edge) - source_node_indices (Layer) : Expected shape (Batch, num_edges) - target_node_indices (Layer) : Expected shape (Batch, num_edges) - - Returns: - (Layer): Expected shape (Batch, Num_edges, self.in_dim_edge) - """ - self.instance += 1 - source_node_features = lbann.Gather(node_features, source_node_indices, axis=0) - target_node_features = lbann.Gather(node_features, target_node_indices, axis=0) - - x = lbann.Concatenation([source_node_features, target_node_features, edge_features], - axis=1, - name=f"{self.name}_{self.instance}_concat_features") - x = self.edge_mlp(x) - - return lbann.Sum(edge_features, x, - name=f"{self.name}_{self.instance}_residual_sum") + """Applies MLP transform on concatenated node and edge features""" + + global_count = 0 + + def __init__( + self, + in_dim_node=128, + in_dim_edge=128, + hidden_dim=128, + hidden_layers=2, + norm_type=lbann.LayerNorm, + name=None, + ): + super().__init__() + EdgeProcessor.global_count += 1 + self.instance = 0 + self.name = name if name else 
f"EdgeProcessor_{EdgeProcessor.global_count}" + + self.edge_mlp = MLP( + 2 * in_dim_node + in_dim_edge, + in_dim_edge, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + norm_type=norm_type, + name=f"{self.name}_edge_mlp", + ) + + def forward( + self, + node_features, + edge_features, + source_node_indices, + target_node_indices, + ): + """ + Args: + node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node) + edge_features (Layer) : Expected shape (Batch, num_edges, self.in_dim_edge) + source_node_indices (Layer) : Expected shape (Batch, num_edges) + target_node_indices (Layer) : Expected shape (Batch, num_edges) + + Returns: + (Layer): Expected shape (Batch, Num_edges, self.in_dim_edge) + """ + self.instance += 1 + source_node_features = lbann.Gather(node_features, source_node_indices, axis=0) + target_node_features = lbann.Gather(node_features, target_node_indices, axis=0) + + x = lbann.Concatenation( + [source_node_features, target_node_features, edge_features], + axis=1, + name=f"{self.name}_{self.instance}_concat_features", + ) + x = self.edge_mlp(x) + + return lbann.Sum( + edge_features, x, name=f"{self.name}_{self.instance}_residual_sum" + ) class NodeProcessor(Module): - """ Applies MLP transform on scatter-summed edge features and node features - """ - global_count = 0 - - - def __init__(self, - num_nodes, - in_dim_node=128, - in_dim_edge=128, - hidden_dim=128, - hidden_layers=2, - norm_type=lbann.LayerNorm, - name=None): - super().__init__() - NodeProcessor.global_count += 1 - self.instance = 0 - self.name = (name if name - else f'NodeProcessor_{NodeProcessor.global_count}') - self.num_nodes = num_nodes - self.in_dim_edge = in_dim_edge - self.node_mlp = MLP(in_dim_node + in_dim_edge, - in_dim_node, - hidden_dim=hidden_dim, - hidden_layers=hidden_layers, - norm_type=norm_type, - name=f"{self.name}_node_mlp") - - def forward(self, - node_features, - edge_features, - target_edge_indices): - """ - Args: - node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node) - edge_features (Layer) : Expected shape (Batch, Num_edges, self.in_dim_edge) - edge_indices (Layer): Expected shape (Batch, Num_edges) - Returns: - (Layer): Expected shape (Batch, Num_nodes, self.in_dim_node) - """ - self.instance += 1 - - edge_feature_sum = lbann.Scatter(edge_features, target_edge_indices, - name=f"{self.name}_{self.instance}_scatter", - dims=[self.num_nodes, self.in_dim_edge], - axis=0) - - x = lbann.Concatenation([node_features, edge_feature_sum], - axis=1, - name=f"{self.name}_{self.instance}_concat_features") - x = self.node_mlp(x) - - return lbann.Sum(node_features, x, - name=f"{self.name}_{self.instance}_residual_sum") + """Applies MLP transform on scatter-summed edge features and node features""" + + global_count = 0 + + def __init__( + self, + num_nodes, + in_dim_node=128, + in_dim_edge=128, + hidden_dim=128, + hidden_layers=2, + norm_type=lbann.LayerNorm, + name=None, + ): + super().__init__() + NodeProcessor.global_count += 1 + self.instance = 0 + self.name = name if name else f"NodeProcessor_{NodeProcessor.global_count}" + self.num_nodes = num_nodes + self.in_dim_edge = in_dim_edge + self.node_mlp = MLP( + in_dim_node + in_dim_edge, + in_dim_node, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + norm_type=norm_type, + name=f"{self.name}_node_mlp", + ) + + def forward(self, node_features, edge_features, target_edge_indices): + """ + Args: + node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node) + edge_features (Layer) : 
Expected shape (Batch, Num_edges, self.in_dim_edge) + edge_indices (Layer): Expected shape (Batch, Num_edges) + Returns: + (Layer): Expected shape (Batch, Num_nodes, self.in_dim_node) + """ + self.instance += 1 + + edge_feature_sum = lbann.Scatter( + edge_features, + target_edge_indices, + name=f"{self.name}_{self.instance}_scatter", + dims=[self.num_nodes, self.in_dim_edge], + axis=0, + ) + + x = lbann.Concatenation( + [node_features, edge_feature_sum], + axis=1, + name=f"{self.name}_{self.instance}_concat_features", + ) + x = self.node_mlp(x) + + return lbann.Sum( + node_features, x, name=f"{self.name}_{self.instance}_residual_sum" + ) class GraphProcessor(Module): - """ Graph processor module - """ - - def __init__(self, - num_nodes, - mp_iterations=15, - in_dim_node=128, in_dim_edge=128, - hidden_dim_node=128, hidden_dim_edge=128, - hidden_layers_node=2, hidden_layers_edge=2, - norm_type=lbann.LayerNorm): - super().__init__() - - self.blocks = [] - - for _ in range(mp_iterations): - node_processor = NodeProcessor(num_nodes=num_nodes, - in_dim_node=in_dim_node, - in_dim_edge=in_dim_edge, - hidden_dim=hidden_dim_node, - hidden_layers=hidden_layers_node, - norm_type=norm_type) - - edge_processor = EdgeProcessor(in_dim_node=in_dim_node, - in_dim_edge=in_dim_edge, - hidden_dim=hidden_dim_edge, - hidden_layers=hidden_layers_edge, - norm_type=norm_type) - - self.blocks.append((node_processor, edge_processor)) - - def forward(self, - node_features, - edge_features, - source_node_indices, - target_node_indices): - """ - Args: - node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node) - edge_features (Layer) : Expected shape (Batch, Num_edges, self.in_dim_edge) - source_node_indices (Layer) : Expected shape (Batch, num_edges) - target_node_indices (Layer) : Expected shape (Batch, num_edges) - Returns: - (Layer, Layer): Expected shape (Batch, Num_nodes, self.in_dim_node) and - (Batch, num_edges, self.in_dim_edge) - """ - - for node_processor, edge_processor in self.blocks: - x = node_processor(node_features, edge_features, target_node_indices) - e = edge_processor(node_features, edge_features, - source_node_indices, - target_node_indices) - - node_features = x - edge_features = e - - return node_features, edge_features + """Graph processor module""" + + def __init__( + self, + num_nodes, + mp_iterations=15, + in_dim_node=128, + in_dim_edge=128, + hidden_dim_node=128, + hidden_dim_edge=128, + hidden_layers_node=2, + hidden_layers_edge=2, + norm_type=lbann.LayerNorm, + ): + super().__init__() + + self.blocks = [] + + for _ in range(mp_iterations): + node_processor = NodeProcessor( + num_nodes=num_nodes, + in_dim_node=in_dim_node, + in_dim_edge=in_dim_edge, + hidden_dim=hidden_dim_node, + hidden_layers=hidden_layers_node, + norm_type=norm_type, + ) + + edge_processor = EdgeProcessor( + in_dim_node=in_dim_node, + in_dim_edge=in_dim_edge, + hidden_dim=hidden_dim_edge, + hidden_layers=hidden_layers_edge, + norm_type=norm_type, + ) + + self.blocks.append((node_processor, edge_processor)) + + def forward( + self, node_features, edge_features, source_node_indices, target_node_indices + ): + """ + Args: + node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node) + edge_features (Layer) : Expected shape (Batch, Num_edges, self.in_dim_edge) + source_node_indices (Layer) : Expected shape (Batch, num_edges) + target_node_indices (Layer) : Expected shape (Batch, num_edges) + Returns: + (Layer, Layer): Expected shape (Batch, Num_nodes, self.in_dim_node) and + (Batch, 
num_edges, self.in_dim_edge) + """ + + for node_processor, edge_processor in self.blocks: + e = edge_processor( + node_features, edge_features, source_node_indices, target_node_indices + ) + edge_features = e + x = node_processor(node_features, edge_features, target_node_indices) + + node_features = x + + return node_features, edge_features diff --git a/applications/graph/MeshGraphNet/README.md b/applications/graph/MeshGraphNet/README.md index 4c9e5ac3108..76d771958a5 100644 --- a/applications/graph/MeshGraphNet/README.md +++ b/applications/graph/MeshGraphNet/README.md @@ -12,3 +12,8 @@ The data-parallel model can be run with the synthetic data with: ```bash python Trainer.py --mini-batch-size --num-epochs ``` + +### Notes + +- This implementation does not distinguish between world nodes and mesh nodes +- We do not currently implement adaptive remeshing, as this may require updating the compute graph after each mini-batch \ No newline at end of file diff --git a/applications/graph/MeshGraphNet/data_config.ini b/applications/graph/MeshGraphNet/data_config.ini new file mode 100644 index 00000000000..0d641e03b4a --- /dev/null +++ b/applications/graph/MeshGraphNet/data_config.ini @@ -0,0 +1,6 @@ +[DEFAULT] +NUM_NODES = 100 +NUM_EDGES = 10000 +EDGE_FEATURES = 3 +NODE_FEATURES = 5 +OUT_FEATURES = 3 From 3d065f018ba92576ae18b7de632dfca1da102d49 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 4 Dec 2023 10:16:26 -0800 Subject: [PATCH 5/5] Clean up some code smells --- applications/graph/MeshGraphNet/GNNComponents.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/applications/graph/MeshGraphNet/GNNComponents.py b/applications/graph/MeshGraphNet/GNNComponents.py index 24c4bede170..f99440d2a60 100644 --- a/applications/graph/MeshGraphNet/GNNComponents.py +++ b/applications/graph/MeshGraphNet/GNNComponents.py @@ -172,17 +172,17 @@ def forward(self, node_features, edge_features, target_edge_indices): """ Args: node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node) - edge_features (Layer) : Expected shape (Batch, Num_edges, self.in_dim_edge) - edge_indices (Layer): Expected shape (Batch, Num_edges) + edge_features (Layer) : Expected shape (Batch, num_edges, self.in_dim_edge) + edge_indices (Layer): Expected shape (Batch, num_edges) Returns: (Layer): Expected shape (Batch, Num_nodes, self.in_dim_node) """ self.instance += 1 - + name = f"{self.name}_{self.instance}" edge_feature_sum = lbann.Scatter( edge_features, target_edge_indices, - name=f"{self.name}_{self.instance}_scatter", + name=f"{name}_scatter", dims=[self.num_nodes, self.in_dim_edge], axis=0, ) @@ -190,12 +190,12 @@ def forward(self, node_features, edge_features, target_edge_indices): x = lbann.Concatenation( [node_features, edge_feature_sum], axis=1, - name=f"{self.name}_{self.instance}_concat_features", + name=f"{name}_concat_features", ) x = self.node_mlp(x) return lbann.Sum( - node_features, x, name=f"{self.name}_{self.instance}_residual_sum" + node_features, x, name=f"{name}_residual_sum" )
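
For reference, the flattened sample layout that `input_data_splitter` in `GNN.py` slices can be checked outside of LBANN. The following is a minimal standalone sketch, assuming the `data_config.ini` added in this series sits in the working directory; the script is illustrative only and is not part of the patched tree.

```python
# Sketch: recompute the slice points that GNN.py's input_data_splitter
# passes to lbann.Slice, using the synthetic configuration values.
import configparser

config = configparser.ConfigParser()
config.read("data_config.ini")

num_nodes = int(config["DEFAULT"]["NUM_NODES"])        # 100
num_edges = int(config["DEFAULT"]["NUM_EDGES"])        # 10000
in_dim_node = int(config["DEFAULT"]["NODE_FEATURES"])  # 5
in_dim_edge = int(config["DEFAULT"]["EDGE_FEATURES"])  # 3
out_dim = int(config["DEFAULT"]["OUT_FEATURES"])       # 3

# Segment sizes, in the order the splitter slices the flat sample:
# node features, edge features, source indices, target indices, targets.
sizes = [
    num_nodes * in_dim_node,  # 500 node feature values
    num_edges * in_dim_edge,  # 30000 edge feature values
    num_edges,                # 10000 source node indices
    num_edges,                # 10000 target node indices
    num_nodes * out_dim,      # 300 output (target) feature values
]

slice_points = [0]
for s in sizes:
    slice_points.append(slice_points[-1] + s)

print(slice_points)  # [0, 500, 30500, 40500, 50500, 50800]
```

These offsets are exactly the `slice_points` handed to `lbann.Slice` in `input_data_splitter`, so any change to the sample layout produced by the Python data reader has to be mirrored there.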