Graph Convolutional Neural Networks

Unofficial reimplementation of "Semi-Supervised Classification with Graph Convolutional Networks" [1][2] in PyTorch.

Usage

Dense

input = ...
adj = ...

n_nodes = adj.size(0)  # or n_nodes = input.size(1)
n_features = input.size(-1)
h1_features = ...
h2_features = ...

conv1 = nn.Sequential(LinearGraphConv(n_features, h1_features), nn.ReLU())
conv2 = nn.Sequential(LinearGraphConv(h1_features, h2_features), nn.ReLU())

output = conv2((conv1((input, adj)), adj))
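
For orientation, the propagation rule from Kipf & Welling [1] is H' = σ(Â H W), with Â = D^{-1/2} (A + I) D^{-1/2}. The sketch below shows one way to build such a normalized dense `adj` and apply a single layer in plain PyTorch. The helpers `normalize_adjacency` and `dense_gcn_layer` are hypothetical names introduced for illustration, not part of this repository, and whether `LinearGraphConv` expects an already normalized adjacency is an assumption to check against its source.

```python
import torch


def normalize_adjacency(adj: torch.Tensor) -> torch.Tensor:
    """Return D^{-1/2} (A + I) D^{-1/2} for a dense, square adjacency matrix.

    Hypothetical helper; assumes non-negative edge weights.
    """
    # Add self-loops so each node keeps its own features during propagation.
    adj_hat = adj + torch.eye(adj.size(0), dtype=adj.dtype, device=adj.device)
    # Node degrees of the self-looped graph (at least 1, so the power is safe).
    deg_inv_sqrt = adj_hat.sum(dim=1).pow(-0.5)
    # Scale rows and columns by D^{-1/2}.
    return deg_inv_sqrt.unsqueeze(1) * adj_hat * deg_inv_sqrt.unsqueeze(0)


def dense_gcn_layer(x: torch.Tensor, adj_norm: torch.Tensor,
                    weight: torch.Tensor) -> torch.Tensor:
    """One propagation step: ReLU(A_norm @ X @ W). Illustrative only."""
    return torch.relu(adj_norm @ (x @ weight))


# Toy graph: a path 0 - 1 - 2.
adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
x = torch.randn(1, 3, 4)    # (batch, n_nodes, n_features)
weight = torch.randn(4, 2)  # (n_features, h1_features)

adj_norm = normalize_adjacency(adj)
out = dense_gcn_layer(x, adj_norm, weight)  # shape (1, 3, 2)
```

The adjacency has no batch dimension here; `@` broadcasts it across the batch of node features.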

Sparse

input = ...
adj_sparse_coo = ...  # adjacency matrix as a sparse COO tensor

n_nodes = adj_sparse_coo.size(0)  # or n_nodes = input.size(1)
n_features = input.size(-1)
h1_features = ...
h2_features = ...

conv1 = nn.Sequential(SparseLinearGraphConv(n_features, h1_features), nn.ReLU())
conv2 = nn.Sequential(SparseLinearGraphConv(h1_features, h2_features), nn.ReLU())

output = conv2((conv1((input, adj_sparse_coo)), adj_sparse_coo))
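
As a rough illustration of where a sparse COO adjacency can come from, the sketch below builds one in two ways with standard PyTorch calls and runs the sparse-dense product that a sparse graph convolution typically relies on. The toy graph and variable names are made up for this example; whether `SparseLinearGraphConv` accepts batched input or expects a normalized adjacency is not stated here and should be checked against its source.

```python
import torch

# Toy graph: a path 0 - 1 - 2, first as a dense matrix.
adj_dense = torch.tensor([[0., 1., 0.],
                          [1., 0., 1.],
                          [0., 1., 0.]])

# Option 1: convert an existing dense matrix to sparse COO.
adj_from_dense = adj_dense.to_sparse()

# Option 2: build it directly from an edge list (2 x n_edges index tensor).
indices = torch.tensor([[0, 1, 1, 2],
                        [1, 0, 2, 1]])
values = torch.ones(indices.size(1))
adj_from_edges = torch.sparse_coo_tensor(indices, values, size=(3, 3)).coalesce()

# Sparse-dense product: each node sums the features of its neighbours.
x = torch.randn(3, 4)  # (n_nodes, n_features), no batch dimension in this sketch
aggregated = torch.sparse.mm(adj_from_edges, x)  # shape (3, 4)
```

Building from an edge list avoids ever materializing the n_nodes x n_nodes dense matrix, which is the main reason to prefer the sparse layers on large graphs.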

TODOs

Add Cora dataset example.

References

  1. Kipf & Welling, "Semi-Supervised Classification with Graph Convolutional Networks", 2016.
  2. Official implementation.