PyG--++ (read "PyG Minus Minus Plus Plus") is a minimal library for GNNs based on PyTorch: a minimalist version of PyTorch Geometric, with some useful features added.
The packages torch, torch_scatter, and torch_sparse must be installed.
Run make to install.
For a graph with node features $X$, edge features $E$, graph-level targets $y$, and connectivity edge_index, we use the Data class, which wraps a dict, to store the graph data. Each key-value pair of a Data object corresponds to one graph feature.
To make it easier to add new features to a graph, we introduce four special "feature classes" (a construction sketch follows the list):
- node_feature: like $X$, with shape (num_nodes * num_node_features), cat_dim=0 and inc=0
- edge_feature: like $E$, with shape (num_edges * num_edge_features), cat_dim=0 and inc=0
- graph_feature: like $y$, with shape (1 * num_graph_features), cat_dim=0 and inc=0
- edge_index: like edge_index, with shape (2 * num_edges), cat_dim=1 and inc=num_nodes
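Below is a minimal construction sketch. The constructor call, field names (x, edge_attr, y), and import path mirror PyTorch Geometric conventions and are assumptions here:

```python
import torch
from pyg_mmpp import Data  # hypothetical import path

num_nodes, num_edges = 5, 8
data = Data(
    x=torch.randn(num_nodes, 16),                            # node_feature
    edge_attr=torch.randn(num_edges, 4),                     # edge_feature
    y=torch.randn(1, 2),                                     # graph_feature
    edge_index=torch.randint(0, num_nodes, (2, num_edges)),  # edge_index
)
```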
We store the keys that belong to those four classes in four distinct sets, and treat each class specially when calling collate() or separate().
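To illustrate the cat_dim/inc semantics used during collation, here is a toy sketch (not the library's actual collate()): node features concatenate along dim 0 with no offset, while edge_index concatenates along dim 1 with each graph's node indices shifted by the number of nodes preceding it.

```python
import torch

# Node features: cat_dim=0, inc=0 -> plain stacking along dim 0.
x1, x2 = torch.randn(3, 16), torch.randn(2, 16)
x = torch.cat([x1, x2], dim=0)  # shape (5, 16)

# edge_index: cat_dim=1, inc=num_nodes -> shift the second graph's
# node indices by the first graph's node count (3), then concatenate.
ei1 = torch.tensor([[0, 1, 2], [1, 2, 0]])
ei2 = torch.tensor([[0, 1], [1, 0]])
edge_index = torch.cat([ei1, ei2 + 3], dim=1)
# tensor([[0, 1, 2, 3, 4],
#         [1, 2, 0, 4, 3]])
```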
To conveniently add a new tensor-type feature to a Data object, we provide the __set_tensor_attr__() method, an extension of __setattr__() that lets the caller decide which of the four "feature classes" (if any) the feature belongs to, and whether the feature needs the auto-batching service. When auto-batching is needed, the caller further chooses whether to create a new slicing vector for the feature or to reuse an existing one. This extension adds little overhead, negligible compared with the performance gained from our simplified framework architecture.
Similarly, a __del_tensor_attr__() method is provided to remove such tensor-type features automatically. The framework therefore enables fast implementation of novel preprocessing techniques, without researchers having to pay attention to the underlying storage details.
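A hypothetical usage sketch follows; the keyword names (feature_class, auto_batch) and the num_nodes property are illustrative assumptions, not the definitive signature:

```python
import torch

# Suppose we computed a per-node positional encoding and want it
# batched like a node_feature. Keyword names below are assumptions.
pe = torch.randn(data.num_nodes, 8)  # assumes a PyG-style num_nodes property
data.__set_tensor_attr__("pos_enc", pe,
                         feature_class="node_feature",  # one of the four classes
                         auto_batch=True)               # request auto-batching

# Remove the feature (and its bookkeeping) when it is no longer needed.
data.__del_tensor_attr__("pos_enc")
```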
For efficient training on GPU, we need to combine a bag of graphs into a batch.
The Batch class inherits from Data and includes three extra fields: batch, ptr, and edge_slice.
- batch maps indices of nodes i to indices of graphs batch[i]
- ptr maps indices of graphs i to indices of nodes ptr[i]
- edge_slice maps indices of graphs i to indices of edges edge_slice[i]
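As a concrete illustration, assuming PyG-style cumulative boundaries (ptr[i] is the first node of graph i, edge_slice[i] its first edge), batching a 3-node/3-edge graph with a 2-node/2-edge graph gives:

```python
import torch

batch = torch.tensor([0, 0, 0, 1, 1])  # node i belongs to graph batch[i]
ptr = torch.tensor([0, 3, 5])          # nodes of graph i are ptr[i]:ptr[i+1]
edge_slice = torch.tensor([0, 3, 5])   # edges of graph i are edge_slice[i]:edge_slice[i+1]
```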
The batching procedure can be applied to a bag of Batch objects in exactly the same way as to Data objects; the result is again a Batch. In this case there are two sets of batch, ptr, and edge_slice vectors, and we append an integer suffix 0, 1, ... to distinguish them. For example, batch0 maps indices of nodes to indices of "individual" graphs, while batch1 maps indices of nodes to indices of input batches.
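A sketch of this recursive batching, assuming a free collate() entry point as described above (the exact call signature is an assumption) and Data objects g1 through g5 constructed as earlier:

```python
b1 = collate([g1, g2])      # a Batch of two graphs
b2 = collate([g3, g4, g5])  # a Batch of three graphs
bb = collate([b1, b2])      # batching Batch objects works the same way
# bb.batch0 maps nodes to "individual" graphs (values 0..4);
# bb.batch1 maps nodes to input batches (0 for b1's nodes, 1 for b2's).
```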
The torch_geometric package offers automatic batching for non-standard graph features (i.e., features other than x, edge_index, edge_attr, and y). We include the same mechanism: if an additional feature falls into any of the four specialized "feature classes", the automatic batching procedure is applied; otherwise, we simply collect the values into a list. We believe this treatment is general enough to cover many interesting models.
We use a Dataset object, which simply wraps a Batch object, to store a graph dataset. When __getitem__ is called on a dataset, we return a single graph if the index is an integer, and a "view" of the original dataset if the index is a slice. This makes splitting the dataset (into train / test, etc.) a zero-copy operation.
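For example (a sketch, where dataset is any Dataset instance):

```python
g = dataset[0]         # integer index: returns one graph (a Data object)
train = dataset[:800]  # slice: returns a zero-copy view of the dataset
test = dataset[800:]   # splitting does not duplicate the underlying Batch
```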
The torch_geometric package processes datasets and batches differently, which makes it a less unified framework. Our profiling on a real-world dataset (QM9) shows that our treatment is both faster (about 0.6x the running time) and more elegant, at the cost of slightly more storage (about 1.07x the disk space).
The DataLoader is a simple wrapper around the torch.utils.data.DataLoader class. As an extension to torch_geometric.loader.DataLoader, our DataLoader allows passing a user-defined collator function as an argument, which overrides the default collate function.
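A usage sketch, assuming the collator is passed via the collate_fn keyword as in torch.utils.data.DataLoader (the keyword name here is an assumption):

```python
def my_collate(data_list):
    # A custom collator could, e.g., filter graphs before batching;
    # here it simply delegates to the library's collate() described above.
    return collate(data_list)

loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    collate_fn=my_collate)
for batch in loader:
    ...  # each `batch` is a Batch object
```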
The MessagePassing class offers a handy way to define graph convolutional operators. To define an MPNN layer, one only needs to implement the message() and update() methods (and, optionally, forward(), for which a default implementation is provided).
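A sketch of a simple convolution layer under this interface; the argument conventions (x_j for source-node features per edge, aggr_out for the per-node aggregate, mirroring torch_geometric) are assumptions about the base class:

```python
import torch
from torch import nn

class SimpleConv(MessagePassing):  # MessagePassing as described above
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)

    def message(self, x_j):
        # x_j: features of source nodes, one row per edge.
        return self.lin(x_j)

    def update(self, aggr_out):
        # aggr_out: messages aggregated per destination node.
        return torch.relu(aggr_out)
```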