
What are Graph Isomorphism Networks?

This article provides a brief overview of Graph Isomorphism Networks (GIN), complete with code examples in PyTorch Geometric and interactive visualizations using W&B.

An Introduction to Graph Isomorphism Networks

In this article, we'll briefly go over the Graph Isomorphism Network (GIN), proposed in How Powerful are Graph Neural Networks? by Xu et al. It's a fundamental model in the message-passing paradigm, built on the Weisfeiler-Lehman (WL) graph isomorphism test.
Message passing is one of the three major paradigms for Graph Neural Networks (GNNs), alongside convolutional and attentional networks.


The Method Behind Graph Isomorphism Networks

The key novelty of this approach is that it can distinguish graphs that are not isomorphic to each other. Simply put, graph isomorphism is an equivalence relation between graphs with the same structure: two graphs are isomorphic if relabeling the nodes of one produces the other.
Graph isomorphism is a purely topological relation: it does not consider node features.
💡
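To see what this means in practice, here's a quick illustration (a minimal sketch using networkx, which isn't used elsewhere in this article): a 4-cycle stays the same graph under any relabeling of its nodes, while a path on four nodes is structurally different.

import networkx as nx

# The 4-cycle 0-1-2-3, and the same cycle with shuffled node labels
g1 = nx.Graph([(0, 1), (1, 2), (2, 3), (3, 0)])
g2 = nx.Graph([(2, 0), (0, 3), (3, 1), (1, 2)])
print(nx.is_isomorphic(g1, g2))  # True: a relabeling maps g1 onto g2

# A 4-cycle is not isomorphic to a path on four nodes (edge counts differ)
g3 = nx.path_graph(4)
print(nx.is_isomorphic(g1, g3))  # False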
As outlined in our fundamental article on message passing, the key distinguishing factor between most message-passing frameworks is the aggregation scheme. In its most general form:

h_u = \phi \left( x_u, \; \bigoplus_{v \in \mathcal{N}_u} \psi(x_u, x_v) \right)

where \mathcal{N}_u denotes the neighbors of node u, \psi computes a message from each neighbor, \bigoplus is a permutation-invariant aggregator (such as sum, mean, or max), and \phi updates the node's representation.

In the case of Graph Isomorphism Networks, the update function takes the following form:

h_u = \phi \left( (1 + \epsilon) \cdot x_u + \sum_{v \in \mathcal{N}_u} x_v \right)

where \phi is an MLP and \epsilon is a scalar (fixed or learnable) that weights the node's own features. Note that the aggregator is a plain sum: unlike mean or max, summation is injective over multisets of neighbor features, which is what lets GIN match the discriminative power of the WL test.
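To make the update concrete, here's a minimal hand-rolled sketch of a single GIN update in plain PyTorch (the gin_update helper and the dense adjacency matrix are for illustration only; the PyTorch Geometric implementation below is what you'd use in practice):

import torch
import torch.nn as nn

def gin_update(x, adj, phi, eps=0.0):
    # x:   [num_nodes, dim] node features
    # adj: [num_nodes, num_nodes] binary adjacency, no self-loops
    neighbor_sum = adj @ x                     # sum over each node's neighbors
    return phi((1 + eps) * x + neighbor_sum)   # combine with the node's own features

phi = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(5, 8)
adj = (torch.rand(5, 5) > 0.5).float()
adj = torch.triu(adj, diagonal=1)
adj = adj + adj.T                              # symmetric, zero diagonal
print(gin_update(x, adj, phi).shape)           # torch.Size([5, 16])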


Implementing the Model with PyTorch Geometric

PyTorch Geometric provides a great implementation of the update rule outlined in the paper: GINConv.
The other two steps in the message-passing paradigm are initialization and transformation. The initialization step is simple: we encode the input node features using an MLP.
For the transformation (readout) step, the paper concatenates graph representations from all layers; to keep things simple, the implementation below pools only the final layer's node embeddings. A paper-style concatenating readout is sketched right after this paragraph.
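For reference, a readout that pools after every layer and concatenates the per-layer graph embeddings might look like this sketch (the concat_readout helper is hypothetical, with parameters mirroring the class below, which keeps the simpler final-layer readout):

import torch
from torch_geometric.nn import global_mean_pool

def concat_readout(x, edge_index, batch, first_layer, mp_layers):
    # Pool node embeddings into a graph embedding after every GIN layer,
    # then concatenate the per-layer graph embeddings.
    embeddings = []
    x = first_layer(x, edge_index)
    embeddings.append(global_mean_pool(x, batch))
    for conv in mp_layers:
        x = conv(x, edge_index)
        embeddings.append(global_mean_pool(x, batch))
    return torch.cat(embeddings, dim=-1)  # [num_graphs, hidden * num_layers]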
Let's walk through an example implementation, shall we?
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d as BN
from torch_geometric.nn import GINConv, global_mean_pool

class GIN(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        ## Initialization Step: encode input node features with an MLP
        self.initialization = GINConv(
            Sequential(
                Linear(dataset.num_features, hidden),
                ReLU(),
                Linear(hidden, hidden),
                ReLU(),
                BN(hidden),
            ),
            eps=0.,
            train_eps=False)
        ## Aggregation Layers
        self.mp_layers = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.mp_layers.append(
                GINConv(
                    Sequential(
                        Linear(hidden, hidden),
                        ReLU(),
                        Linear(hidden, hidden),
                        ReLU(),
                        BN(hidden),
                    ),
                    eps=0.,
                    train_eps=False)
            )
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.initialization(x, edge_index)
        for conv in self.mp_layers:
            x = conv(x, edge_index)
        ## Readout: average node embeddings into one graph embedding
        x = global_mean_pool(x, batch)
        ## Classification Head
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)
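To put the model to work, here's a minimal training sketch. The hidden size, learning rate, and batch size are illustrative choices, not the exact settings used for the results below; PROTEINS comes from the TUDataset collection.

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='data', name='PROTEINS')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = GIN(dataset, num_layers=5, hidden=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for data in loader:
    optimizer.zero_grad()
    out = model(data)                # log-probabilities per graph
    loss = F.nll_loss(out, data.y)   # pairs with the log_softmax output
    loss.backward()
    optimizer.step()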

The Results

We trained our models for 50 epochs to perform graph classification on the PROTEINS dataset, using 5 GIN layers. Here's our training loss and accuracy:

[Interactive W&B panels: training loss and accuracy across the run set (2 runs)]


Summary

In this article, we learned about the Graph Isomorphism Network, a popular method built on the message-passing framework, along with code and interactive visualizations.
If you want more reports covering graph neural networks with code implementations, let us know in the comments below or on our forum ✨!
Check out these other reports on Fully Connected covering other graph neural network topics and ideas.
To see the full suite of W&B features, please check out this short 5-minute guide.
