Graph Representation Learning

Graph Representation Learning is the task of effectively summarizing the structure of a graph in a low-dimensional embedding. With the rise of deep learning, researchers have come up with various architectures that use neural networks for graph representation learning. We call such architectures Graph Neural Networks (GNNs).

Experiment with Graph Neural Networks in this Google Colab Notebook →

Why do we Need Graph Neural Networks?

Before getting into the specifics of GNNs, let us first understand the motivation behind them. One might ask whether existing deep learning architectures such as CNNs and RNNs can be used for graph representation learning. However, CNNs specialize in grid-like structures such as images, whereas RNNs specialize in sequential data. Graphs, in contrast, have no fixed grid or ordering and can have arbitrary size and connectivity, so they benefit from methods designed to learn in such a setting.

Graphs are everywhere!
Many data stores have an inherent graph structure, and leveraging this structure can have a lot of impact. GNNs can be leveraged for better recommendation systems, understanding the economic impact of various components, drug discovery, understanding how neurons in the brain function, and many other applications.

Problem Statements

Node Classification

Predicting the labels of unlabeled nodes using the nodes whose labels are known. For example, classifying social media accounts as bots or humans.

Graph Classification

It involves classifying an entire graph into predefined categories. For example, predicting whether a molecule is soluble in water based on its atomic structure.

Link Prediction

Determining whether a particular node in the graph is linked with some other node. For example, it can be used in social networks for predicting whether a person knows another person and suggesting friends.

Graph Regression

The graph regression task is analogous to graph classification, except that the target is a continuous value for the entire graph rather than a discrete category. It therefore uses a regression loss function and performance metric, such as mean squared error or mean absolute error, instead of a classification loss.

Generating Node Embeddings

The task of learning representations from the graph structure involves creating a low-dimensional embedding for each node of the graph. The goal is that nodes which are similar in the graph should be close to each other in the low-dimensional Euclidean space. This task is depicted in the image below.

Node embedding.png

-> Graph Embedding — Representation Learning on Networks <-

We discuss two basic approaches to node embedding: (1) classical "shallow" methods, and (2) neural-network based methods.

Shallow Embedding Methods

The typical characteristic of shallow embedding methods is that the encoder grows in size linearly with the number of nodes in the graph. The retrieval of a node embedding is essentially a table lookup, hence the name "shallow".

Consider the simplest shallow encoding scheme, shown in the image below:


-> Shallow Embedder <-

The task of learning shallow node embeddings typically involves three steps:

1. Define an encoder that maps nodes to embeddings:

  $ENC(v) = Z v$

where $Z \in \mathbb{R}^{d \times |V|}$ is the embedding matrix, $|V|$ is the number of nodes in the graph, and $v$ is the one-hot encoded vector for a node in the graph.

2. Define a function that measures the similarity between two nodes, i.e., $similarity(u, v)$.

3. Optimize the parameters of the encoder such that $similarity(u, v) \approx z_v^\top z_u$, where $z_u$ and $z_v$ are the embeddings of nodes $u$ and $v$.
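The three steps above can be sketched in a few lines of NumPy. This is only an illustrative sketch: the toy path graph, the choice of the adjacency matrix as the similarity function, the squared-error loss, and the learning rate are all assumptions made for this example and are not taken from any particular shallow embedding method.

```python
import numpy as np

# Toy undirected path graph 0-1-2-3-4 (hypothetical example data).
A = np.zeros((5, 5))
for u, v in [(0, 1), (1, 2), (2, 3), (3, 4)]:
    A[u, v] = A[v, u] = 1.0

num_nodes, dim = A.shape[0], 2
rng = np.random.default_rng(0)
Z = 0.1 * rng.normal(size=(dim, num_nodes))  # step 1: one column of Z per node


def encode(Z, node):
    """ENC(v) = Z v with a one-hot v, which simply selects column `node` of Z."""
    v = np.zeros(Z.shape[1])
    v[node] = 1.0
    return Z @ v


def loss(Z, A):
    """Squared error between z_u^T z_v and the assumed similarity A[u, v] (step 2)."""
    return float(np.sum((Z.T @ Z - A) ** 2))


# Step 3: plain gradient descent on the entries of Z.
initial_loss = loss(Z, A)
for _ in range(200):
    grad = 4.0 * Z @ (Z.T @ Z - A)  # gradient of the loss above for symmetric A
    Z -= 0.01 * grad
final_loss = loss(Z, A)
```

Note that the only trainable parameters are the entries of `Z` itself, which is why the encoder grows linearly with the number of nodes.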

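Other shallow encoding based approaches include Node2Vec, DeepWalk, and LINE. However, such approaches share some inherent problems: the number of parameters grows linearly with the number of nodes because no parameters are shared between nodes, embeddings cannot be produced for nodes that were not seen during training, and node features are not used.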

Graph Neural Networks

As seen above, shallow embedding methods have certain limitations that impact their ability to perform in real-life scenarios. To address these issues, we turn to GNNs. There are broadly two types of neural-network based node embedding approaches:

1. Message Passing Based GCNs

2. Weisfeiler-Lehman GNNs based on the WL test[1]

In this report, we discuss only message-passing based GCNs, since current WLGNN based approaches are not scalable: their space and time complexity grows polynomially with graph size, which generally makes WLGNNs intractable for large datasets.

Message Passing Based Graph Convolutional Networks

Message-passing based GCNs rely on a concept called neighborhood aggregation. The embedding of each node is computed by aggregating (for example, averaging) the embeddings of its neighbors and passing the result through a neural network. A single step only looks at immediate neighbors, which may seem too local. However, the procedure is iterative: after $k$ iterations, each node's embedding depends on all nodes within $k$ hops, so with enough iterations the aggregation can cover the entire graph. This process can be visualized in the image below.


-> Shared GNN Encoder <-

Formally, this process can be summarized as the following steps:

1. Define a neighborhood aggregation function
The simplest case is where each node considers all its neighbors equally, that is, all the neighbors are equally weighted and have equal impact on the generated node embedding. Such architectures are called isotropic, e.g., GCN [2] and GraphSAGE [3]. The equation below shows the node update operation for isotropic architectures.
$h_i^{l+1} = \sigma\left(W_1^l h_i^l + \sum_{j \in N_i} W_2^l h_j^l\right), \quad h^l, h^{l+1} \in \mathbb{R}^{n \times d}, \quad W_{1,2}^l \in \mathbb{R}^{d \times d}$
where $\sigma$ is an activation function such as ReLU and $N_i$ is the set of neighbors of node $i$.
The more advanced architectures are anisotropic, that is, each neighbor has a different amount of impact on the generated node embedding, e.g., MoNet [4], GAT [5], and GatedGCN [6]. The following equation shows the node update operation for anisotropic architectures.
$h_i^{l+1} = \sigma\left(W_1^l h_i^l + \sum_{j \in N_i} \eta_{ij} W_2^l h_j^l\right), \quad h^l, h^{l+1} \in \mathbb{R}^{n \times d}, \quad W_{1,2}^l \in \mathbb{R}^{d \times d}$
where $\eta_{ij} = f^l(h_i^l, h_j^l)$ and $f^l$ is a parameterized function whose weights are learned during training.
Observe that $\eta_{ij}$ acts as an attention mechanism.

2. Define the loss function
The loss function is defined based on the task for which predictions are to be made. For example, for a binary node classification task, the loss function would be the binary cross-entropy (BCE) loss on the predicted node labels. Thus, the learned node embeddings would be effective at extracting features that help determine the class of the node.

3. Train
Based on the defined loss function, we apply gradient descent to update the parameters, which are shared across all nodes of the network. Since each update step depends only on local neighborhoods, this operation is highly scalable and allows for batch operations.
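The two node update rules from step 1 can be sketched in NumPy as follows. This is a minimal sketch under stated assumptions: node features are stored as rows of `h`, the graph is given as a dense adjacency matrix `adj`, and the anisotropic weight $\eta_{ij}$ uses a hypothetical choice of $f^l$ (a sigmoid of a linear score on the concatenated pair $[h_i, h_j]$ with a weight vector `a`). Real architectures such as GAT or GatedGCN define $f^l$ differently.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def isotropic_layer(h, adj, W1, W2):
    """h_i' = ReLU(W1 h_i + sum_{j in N_i} W2 h_j): all neighbors weighted equally."""
    # Rows of h are node features, so W h_i for every node at once is h @ W.T.
    return relu(h @ W1.T + adj @ (h @ W2.T))


def anisotropic_layer(h, adj, W1, W2, a):
    """Same update, but neighbor j is scaled by eta_ij = f(h_i, h_j)."""
    n = h.shape[0]
    hi = np.repeat(h[:, None, :], n, axis=1)  # hi[i, j] = h_i
    hj = np.repeat(h[None, :, :], n, axis=0)  # hj[i, j] = h_j
    # Hypothetical f: sigmoid of a linear score on the concatenation [h_i, h_j].
    scores = np.concatenate([hi, hj], axis=-1) @ a  # shape (n, n)
    eta = adj / (1.0 + np.exp(-scores))  # adj * sigmoid(scores); zero where no edge
    return relu(h @ W1.T + eta @ (h @ W2.T))


# Toy path graph 0 - 1 - 2 with random features and weights (illustrative values).
rng = np.random.default_rng(0)
n, d = 3, 4
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
h = rng.normal(size=(n, d))
W1, W2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))
a = rng.normal(size=2 * d)

h_iso = isotropic_layer(h, adj, W1, W2)
h_aniso = anisotropic_layer(h, adj, W1, W2, a)
```

After one layer, the embedding of node 0 depends only on nodes 0 and 1. Stacking a second layer lets information from node 2 reach node 0, which is the compounding effect of iterated neighborhood aggregation.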

Now that we have an understanding of message-passing based GCNs, in the following section we look at the Gated Graph Convolutional Network (GatedGCN) [6], which is one of the best performing architectures of this type.

Gated Graph Convolutional Network

The GatedGCN architecture is an anisotropic message-passing based GNN that employs residual connections, batch normalization, and edge gates. The given figure summarizes each layer of the GatedGCN network.


-> A layer of GatedGCN [7] <-

$h_i^{l+1} = h_i^l + \mathrm{ReLU}\left(\mathrm{BN}\left(U^l h_i^l + \sum_{j \in N_i} e_{ij}^l \odot V^l h_j^l\right)\right),$
where $U^l, V^l \in \mathbb{R}^{d \times d}$, $\odot$ denotes element-wise multiplication, $N_i$ is the set of neighbors of node $i$, and $e_{ij}^l$ are the edge gates, defined as follows:

$e_{ij}^l = \frac{\sigma(\hat{e}_{ij}^l)}{\sum_{j' \in N_i} \sigma(\hat{e}_{ij'}^l) + \varepsilon}$

$\hat{e}_{ij}^l = \hat{e}_{ij}^{l-1} + \mathrm{ReLU}\left(\mathrm{BN}\left(A^l h_i^{l-1} + B^l h_j^{l-1} + C^l \hat{e}_{ij}^{l-1}\right)\right)$
where $\sigma$ is the sigmoid function, $\varepsilon$ is a small fixed constant for numerical stability, and $A^l, B^l, C^l \in \mathbb{R}^{d \times d}$.
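To make the equations more concrete, here is a rough NumPy sketch of a single GatedGCN layer. It is not the reference implementation: the edge-list layout (`dst` and `src` arrays), the function names, and the simplified batch normalization without learned scale and shift are assumptions made for illustration. The function takes $h^l$ and $\hat{e}^l$ and returns $h^{l+1}$ together with $\hat{e}^{l+1}$, that is, the edge equation with its layer index shifted by one.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def batch_norm(x, eps=1e-5):
    """Simplified BN: normalize each feature over the batch, no learned scale/shift."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


def gated_gcn_layer(h, e_hat, dst, src, U, V, A, B, C, eps=1e-6):
    """Edge k carries a message from neighbor j = src[k] to node i = dst[k]."""
    # Edge gates e_ij: sigmoid of e_hat, normalized over the neighbors of node i.
    gate = sigmoid(e_hat)                        # shape (num_edges, d)
    denom = np.zeros_like(h)
    np.add.at(denom, dst, gate)                  # sum_{j'} sigmoid(e_hat_ij') per node
    e = gate / (denom[dst] + eps)
    # Gated aggregation: sum_j e_ij * (V h_j), element-wise over features.
    agg = np.zeros_like(h)
    np.add.at(agg, dst, e * (h[src] @ V.T))
    h_next = h + relu(batch_norm(h @ U.T + agg))  # residual connection
    # Edge update with its own residual connection.
    e_hat_next = e_hat + relu(batch_norm(h[dst] @ A.T + h[src] @ B.T + e_hat @ C.T))
    return h_next, e_hat_next, e


# Toy path graph 0 - 1 - 2 - 3, stored as directed edges in both directions.
rng = np.random.default_rng(0)
d = 4
dst = np.array([0, 1, 1, 2, 2, 3])  # receiving node i of each edge
src = np.array([1, 0, 2, 1, 3, 2])  # neighbor j that sends the message
h = rng.normal(size=(4, d))
e_hat = rng.normal(size=(len(dst), d))
U, V, A, B, C = (rng.normal(size=(d, d)) for _ in range(5))

h_next, e_hat_next, e = gated_gcn_layer(h, e_hat, dst, src, U, V, A, B, C)
```

Because the gates are normalized per node and per feature, the gates of all edges arriving at a node sum to approximately one in every feature dimension.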

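That is a lot of equations, so let's look at the salient features they exhibit. The term $h_i^l$ added outside the ReLU in the node update is a residual connection. Batch normalization (BN) is applied before the ReLU in both the node and edge updates. The edge gates $e_{ij}^l$ are normalized over the neighbors of node $i$, so they act as a soft attention mechanism that weights each neighbor differently. Finally, the edge features $\hat{e}_{ij}^l$ are themselves updated from layer to layer, with their own residual connection.

Now that we have looked at the features of GatedGCN, in the following section we experiment with it.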


We train our network on the PATTERN dataset proposed by Dwivedi et al. [7] for binary node classification. It is an artificial dataset generated using Stochastic Block Models (SBMs), which are known to controllably model communities in social networks. Our train/validation/test split is 10K/2K/2K graphs, with an average of 117 nodes and 4749 edges per graph.
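For readers unfamiliar with SBMs, the following sketch samples a single SBM graph with NumPy. It only illustrates the general idea of a stochastic block model and is not the procedure used to generate PATTERN; the function name `sample_sbm`, the block sizes, and the probabilities `p_in` and `p_out` are hypothetical values chosen for this example.

```python
import numpy as np


def sample_sbm(block_sizes, p_in, p_out, seed=0):
    """Sample an undirected SBM graph.

    Nodes in the same block are connected with probability p_in,
    nodes in different blocks with probability p_out.
    """
    rng = np.random.default_rng(seed)
    labels = np.repeat(np.arange(len(block_sizes)), block_sizes)
    same_block = labels[:, None] == labels[None, :]
    probs = np.where(same_block, p_in, p_out)
    # Sample each unordered pair once (strict upper triangle), then mirror it.
    upper = np.triu(rng.random(probs.shape) < probs, k=1)
    adj = (upper | upper.T).astype(int)  # symmetric, no self-loops
    return adj, labels


# Two communities with denser connections inside each block (toy values).
adj, labels = sample_sbm([6, 4], p_in=0.8, p_out=0.1)
```

Raising `p_in` relative to `p_out` makes the communities easier to separate, which is the sense in which SBMs model community structure controllably.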

The experiments are carried out in PyTorch. The Colab notebook is available at this link. We use the Adam optimizer along with the ReduceLROnPlateau scheduler. The initial learning rate was $10^{-3}$ and the reducing factor was 0.5. The GPU allotted by Colab at the time of the experiment was an Nvidia T4. The model converges in about 50 epochs, and each epoch takes about 280s.
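The idea behind a reduce-on-plateau schedule can be sketched in plain Python. This is a simplified illustration of the rule, not PyTorch's implementation: the function name and the `patience` value are assumptions (the patience used in the experiment is not stated here), while the initial learning rate and the factor match the values above.

```python
def reduce_on_plateau(val_losses, lr=1e-3, factor=0.5, patience=5):
    """Return the learning rate used in each epoch under a simplified plateau rule.

    The learning rate is multiplied by `factor` once the validation loss has
    failed to improve for more than `patience` consecutive epochs.
    """
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        history.append(lr)  # learning rate in effect during this epoch
        if loss < best:
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
    return history


# Validation loss stalls at 0.9, so the rate is halved after the plateau.
schedule = reduce_on_plateau([1.0, 0.9, 0.9, 0.9, 0.9, 0.9], patience=2)
```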

As seen in the Validation Accuracy vs Epoch graph, the model converges with a validation accuracy of 84.37%.

Finally, the classification accuracy on the test split is 84.5%.

Full code →

Conclusion


In this report, we looked at the motivation behind graph representation learning and the need for GNNs. We also broadly surveyed the GNN based approaches that are currently being researched and gained an intuitive understanding of the methods. Finally, we examined the internal workings of GatedGCN, which is one of the best performing GNN architectures.

While this report serves as an overview of the methods being used for graph representation learning, we have only scratched the surface in terms of the depth of the literature currently available. I would like to encourage readers to look at the paper Benchmarking Graph Neural Networks by Dwivedi et al. [7]. The paper goes into more detail on current GNN architectures and provides a standardized comparison of the methods. The authors have also published their code on GitHub. The code used for this experiment is a simplified version of their code. As an exercise, readers could train and compare all the architectures using Weights and Biases Reports.

Check out Part 2 of this post, where I compare message-passing based GNN architectures.


  1. Boris Weisfeiler and Andrei A. Lehman. A reduction of a graph to a canonical form and an algebra arising during this reduction. Nauchno-Technicheskaya Informatsia, 2(9):12–16, 1968.

  2. Thomas N. Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. In International Conference on Learning Representations (ICLR), 2017.

  3. Will Hamilton, Zhitao Ying, and Jure Leskovec. Inductive representation learning on large graphs. In Advances in Neural Information Processing Systems, pages 1024–1034, 2017.

  4. Federico Monti, Davide Boscaini, Jonathan Masci, Emanuele Rodola, Jan Svoboda, and Michael M. Bronstein. Geometric deep learning on graphs and manifolds using mixture model CNNs. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017.

  5. Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua Bengio. Graph Attention Networks. In International Conference on Learning Representations (ICLR), 2018.

  6. Xavier Bresson and Thomas Laurent. Residual gated graph convnets. arXiv preprint arXiv:1711.07553, 2017.

  7. Vijay Dwivedi, Chaitanya Joshi, Thomas Laurent, Yoshua Bengio, and Xavier Bresson. Benchmarking Graph Neural Networks. arXiv preprint arXiv:2003.00982, 2020.