
An Introduction to GraphSAGE

This article provides an overview of the GraphSAGE neural network architecture, complete with code examples in PyTorch Geometric, and visualizations using W&B.
Created on August 10 | Last edited on June 28
In this article, we'll dive into GraphSAGE, a general inductive framework that leverages node feature information, proposed in the paper Inductive Representation Learning on Large Graphs by William L. Hamilton, Rex Ying, and Jure Leskovec.
Let's dive in!

What is GraphSAGE?

GraphSAGE is a fundamental model in the graph convolutional network paradigm, a family of architectures inspired by the Weisfeiler-Lehman (WL) graph isomorphism test. It's a popular framework for inductive representation learning on large graphs and is especially useful for graphs with rich node attribute information.
There are three main classes of graph neural networks, namely message-passing graph neural networks, graph convolutional networks, and graph attention networks. For a brief overview of the three paradigms, you can refer to the following blogs:


The Method for Using GraphSAGE

The paper studies the task of generating low-dimensional vector embeddings for the nodes of a graph. It's worth noting that real-life tasks often involve dynamic graphs, meaning there is a possibility of encountering previously unseen nodes.
The key novelty behind this approach is a neighborhood sampling step, which solves a major limitation of its predecessor, DeepWalk. Each node of the graph is represented as some aggregation of its neighbors. Therefore, even if a new, unseen node is encountered, we can still represent it as an aggregation of its neighborhood.
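To make the sampling step concrete, here is a minimal sketch of uniform fixed-size neighbor sampling over a toy adjacency list. The function name `sample_neighbors` and the `adj` dictionary are illustrative choices, not part of the paper's code or any library; production implementations (e.g. in PyTorch Geometric's loaders) work on tensorized graphs.

```python
import random

def sample_neighbors(adj, node, num_samples, seed=0):
    """Uniformly sample a fixed-size set of neighbors for one node.

    adj: dict mapping node id -> list of neighbor ids (toy adjacency list).
    If the node has fewer neighbors than num_samples, we sample with
    replacement so the output size stays fixed.
    """
    rng = random.Random(seed)
    neighbors = adj[node]
    if len(neighbors) >= num_samples:
        return rng.sample(neighbors, num_samples)
    return [rng.choice(neighbors) for _ in range(num_samples)]

# Toy graph: node 0 is connected to nodes 1-4.
adj = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0]}
sampled = sample_neighbors(adj, 0, 2)
print(sampled)  # a fixed-size subset of node 0's neighbors
```

Because the aggregation only ever sees a fixed-size sampled neighborhood, the per-node cost is bounded regardless of how large the graph is, and the same procedure applies unchanged to nodes that appear after training.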
As outlined in our fundamental article on graph convolutional networks, the most general way of expressing convolutional graph nets is as follows:
\huge h_u = \phi \left( x_u, \, \bigoplus_{v \in \mathcal{N}_u} c_{uv} \, \psi(x_v) \right)

where $c_{uv}$ is a coefficient weighting the contribution of node $v$'s features to node $u$. In the case of GraphSAGE, the update function takes the following form:
\huge h_u = W x_u + W' \displaystyle \frac{1}{d_u} \sum_{v \in \mathcal{N}_u} x_v

where $W$ and $W'$ are distinct weight matrices and $d_u$ is the degree of node $u$, so the neighborhood features are mean-aggregated.
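The update rule above can be traced by hand for a single node. The sketch below uses scalar features and scalar weights purely for readability ($W$ and $W'$ are matrices in practice); `sage_mean_update` is an illustrative name, not a library function.

```python
def sage_mean_update(x_u, neighbor_xs, W, W_prime):
    """One GraphSAGE mean-aggregator update for a single node u,
    with scalar features and weights for readability.

    x_u: feature of node u
    neighbor_xs: features of u's sampled neighbors
    W, W_prime: the two distinct weights from the update rule
    """
    # (1/d_u) * sum over neighbors = neighborhood mean
    mean_neighbors = sum(neighbor_xs) / len(neighbor_xs)
    # h_u = W * x_u + W' * mean of neighbor features
    return W * x_u + W_prime * mean_neighbors

# Node with feature 2.0 and three neighbors:
h = sage_mean_update(x_u=2.0, neighbor_xs=[1.0, 3.0, 5.0], W=0.5, W_prime=0.1)
print(h)  # 0.5 * 2.0 + 0.1 * mean([1, 3, 5]) = approximately 1.3
```

Note how the node's own features and the aggregated neighborhood pass through separate weights; this is what lets the model learn to balance self-information against neighborhood structure.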

Implementing GraphSAGE

PyTorch Geometric provides a great implementation of the update rule outlined in the paper (SAGEConv).
Let's walk through an example implementation!
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import SAGEConv, global_mean_pool

class GraphSAGE(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        # Pool node embeddings into a single graph-level embedding
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

GraphSAGE: The Results

We train models with 5 layers for 50 epochs on graph classification over the REDDIT-BINARY dataset and report the training loss and accuracy.

[W&B panel: training loss and accuracy for the run set (3 runs)]

Summary

In this article, we learned about GraphSAGE, a popular graph neural network architecture, along with code and interactive visualizations. To see the full suite of W&B features, please check out this short 5-minute guide.
If you want more reports covering graph neural networks with code implementations, let us know in the comments below or on our forum ✨!
Check out these other reports on Fully Connected covering other Graph Neural Networks-based topics and ideas.
