
A Brief Introduction to Graph Attention Networks

This article provides a brief overview of the Graph Attention Networks architecture, complete with code examples in PyTorch Geometric and interactive visualizations using W&B.
In this article, we'll briefly go over the graph attention networks (GAT) architecture proposed in the paper Graph Attention Networks by Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua Bengio.
This is a foundational model of the graph attention network paradigm (arguably the first paper in this line of work), inspired by work on graph convolutional networks and by the attention literature from the sequence-to-sequence paradigm.
There are three main classes of graph neural networks: message passing graph neural networks, graph convolutional networks, and graph attention networks. For a brief overview of the three paradigms, you can refer to the following blogs:

Here's what we'll be covering:

Table of Contents

How to Implement Graph Attention Networks
Implementing the Model
Results
Summary

Let's get going!

How to Implement Graph Attention Networks

In our article on residual gated graph convolution networks, we looked into the problem of variable-sized graphs. Models like residual gated graph convolution networks (and their much simpler predecessor, gated graph convolution networks) employ gated recurrent units to allow the model to learn even on graphs of variable size. Still: we can borrow one more essential idea from the sequence-to-sequence paradigm to improve our models. And yes, the answer is indeed attention!
Simply put: we learn mappings between inputs that allow our models to pay attention to what needs to be learned. For a more thorough introduction to attention, you can refer to these articles:



As beautifully summed up by the authors:
"The idea is to compute the hidden representations of each node in the graph, by attending over its neighbours, following a self-attention strategy"
As discussed in the introductory blog, in its simplest form, attention in graph networks calculates the similarity between two node representations using the following formulation:
$$h_u = \phi\left(x_u,\; \bigoplus_{v \in \mathcal{N}_u} \psi(x_u, x_v)\right)$$

where $\psi$ is the attention mechanism which calculates the attention score.
The attention mechanism described in the paper is as follows:
$$h_u = \alpha_{uu} W_1 x_u + \sum_{v \in \mathcal{N}_u} \alpha_{uv} W_2 x_v$$

where the attention coefficients are calculated using:
$$\alpha_{uv} = \frac{\exp\left(\text{LeakyReLU}\left(w^{T} [W x_u \,\|\, W x_v]\right)\right)}{\sum_{k \in \mathcal{N}_u \cup \{u\}} \exp\left(\text{LeakyReLU}\left(w^{T} [W x_u \,\|\, W x_k]\right)\right)}$$
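
To make the mechanism concrete, here is a minimal sketch (not the library implementation) of how the attention coefficients and the updated representation of a single node could be computed with plain PyTorch. The tensor shapes and the neighbourhood are toy assumptions:

import torch
import torch.nn.functional as F

# Toy setup: node u with two neighbours, input dim 4, hidden dim 8
x_u = torch.randn(4)
neighbours = [torch.randn(4), torch.randn(4)]

W = torch.randn(8, 4)   # shared linear transformation
w = torch.randn(16)     # attention vector applied to [W x_u || W x_v]

# Unnormalised scores e_uv = LeakyReLU(w^T [W x_u || W x_v]),
# including the self-loop term that produces alpha_uu
candidates = [x_u] + neighbours
scores = torch.stack([
    F.leaky_relu(w @ torch.cat([W @ x_u, W @ x_v]), negative_slope=0.2)
    for x_v in candidates
])

# Softmax over the neighbourhood yields the attention coefficients alpha_uv
alpha = torch.softmax(scores, dim=0)

# The new representation h_u is the attention-weighted sum of transformed features
h_u = torch.stack([a * (W @ x_v) for a, x_v in zip(alpha, candidates)]).sum(dim=0)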


Implementing the Model

As with other models discussed in the series, we turn to PyTorch Geometric again, which provides an implementation of the attention mechanism outlined in the paper as the GATConv layer.
Let's walk through a minimal example implementation:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

DROPOUT_RATE: float = 0.5

class GATv1(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super().__init__()
        # First layer: multi-head attention; head outputs are concatenated,
        # so the next layer receives hidden_channels * heads features
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=DROPOUT_RATE)
        # Output layer: a single head producing the class logits
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1,
                             concat=False, dropout=DROPOUT_RATE)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=DROPOUT_RATE, training=self.training)
        x = self.conv1(x, edge_index).relu()
        x = F.dropout(x, p=DROPOUT_RATE, training=self.training)
        x = self.conv2(x, edge_index)
        return x
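
Here is a quick usage sketch, assuming the Planetoid loader from PyTorch Geometric for the Cora dataset; the hidden dimension and number of heads are illustrative values rather than tuned settings:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root="data/Cora", name="Cora")
data = dataset[0]

model = GATv1(
    in_channels=dataset.num_features,   # 1433 bag-of-words features per paper
    hidden_channels=8,                  # illustrative hidden dimension
    out_channels=dataset.num_classes,   # 7 paper categories
    heads=8,
)

out = model(data.x, data.edge_index)    # logits of shape [num_nodes, num_classes]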

Results

We train several models for 50 epochs to perform node classification on the Cora dataset, using the minimal model implementation above, and report the training loss and accuracy to compare the effect of the hidden dimension on overall performance.
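
For completeness, a minimal training-loop sketch with W&B logging along these lines might look like the following; it continues from the usage snippet above, and the project name, learning rate, and weight decay are placeholders rather than the exact settings behind the runs shown below:

import wandb

wandb.init(project="gat-cora", config={"hidden_channels": 8, "heads": 8, "epochs": 50})

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(50):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    # Log training loss and accuracy on the labelled training nodes
    pred = out.argmax(dim=-1)
    acc = (pred[data.train_mask] == data.y[data.train_mask]).float().mean()
    wandb.log({"train/loss": loss.item(), "train/accuracy": acc.item(), "epoch": epoch})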

[W&B panel: training loss and accuracy curves for the run set (3 runs)]


Summary

In this article, we learned about the Graph Attention Network architecture, along with code and interactive visualizations. To see the full suite of W&B features, please check out this short 5-minute guide.
If you want more reports covering graph neural networks with code implementations, let us know in the comments below or on our forum ✨!
Check out these other reports on Fully Connected covering other graph neural network-based topics and ideas.
