
GraphCL: Graph Contrastive Learning Framework with Augmentations

Graph Contrastive Learning framework as outlined in "Graph Contrastive Learning with Augmentations" by You et al.
NOTE: This report is part of a series of reports on Graph Representation Learning. For a brief overview and survey, please refer to the following articles as well.


Introduction

Self-Supervised Learning is a special form of Unsupervised Learning in which the supervision is provided by the data itself. Rather than using labels denoting classes or some other form of target information, if we simply want to learn representations, perhaps to act as a prior for later fine-tuning, we can mask out a part of the data and then train the model to recreate the missing information.
One of the key pillars of Self-Supervised Learning is the family of contrastive methods, in which we distort or perturb the initial data through semantics-preserving augmentations and then train a model on the resulting views. The contrastive bit comes in when we further push the model to group views of the same data point together and away from views of different data points. This is the crux of Contrastive Learning: this simple notion of grouping views of the same instance leads to profound results and has led to great techniques such as SimCLR.
In this article we will cover a simple contrastive framework for Graph Representation Learning called GraphCL, as outlined in the paper "Graph Contrastive Learning with Augmentations" by You et al.
NOTE: We assume a basic understanding of Graph Neural Networks; if you need a quick refresher, the following article is recommended.


👨‍🏫 Method

Figure 1: GraphCL Framework
Similar to the GRACE framework introduced in a related article, A Brief Introduction to Graph Contrastive Learning, the GraphCL framework follows best practices from self-supervised techniques explored with other modalities, viz. a shared encoder and a projection head. The framework GraphCL most closely resembles is SimCLR, which has been explored previously.

The GraphCL framework can be summarised as follows:
  • Given a graph $\mathcal{G}$, we generate two views $\hat{\mathcal{G}_i}, \hat{\mathcal{G}_j}$ by performing augmentations. The authors select these augmentations based on the graph domain.
  • These two views $\hat{\mathcal{G}_i}, \hat{\mathcal{G}_j}$ are then passed through a graph encoder $f(\cdot)$, yielding representations $h_i, h_j$. The graph encoder can be any GNN architecture.
  • These representations are then passed through a projection head $g(\cdot)$, a simple MLP network, which generates two projected views $z_i, z_j$.
  • We then apply a contrastive objective $\mathcal{L}(\cdot)$ between the two views; the objective in this case is the normalized temperature-scaled cross-entropy loss (NT-Xent).
NOTE: As is considered best practice in Self-Supervised Learning, we don't explicitly sample negative pairs; instead, the augmented views of the other graphs in a batch act as the negative pairs.
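To make the pipeline above concrete, here is a minimal sketch of a graph encoder $f(\cdot)$ and projection head $g(\cdot)$ written with PyTorch Geometric. The GIN-style layers, hidden size, and sum pooling are illustrative assumptions; GraphCL itself is agnostic to the encoder architecture.

import torch
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import GINConv, global_add_pool


class Encoder(torch.nn.Module):
    """A small GIN-style graph encoder f(.) producing graph-level representations h."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.conv1 = GINConv(Sequential(Linear(in_dim, hidden_dim), ReLU(), Linear(hidden_dim, hidden_dim)))
        self.conv2 = GINConv(Sequential(Linear(hidden_dim, hidden_dim), ReLU(), Linear(hidden_dim, hidden_dim)))

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return global_add_pool(x, batch)  # graph-level representation h


class ProjectionHead(torch.nn.Module):
    """A simple MLP g(.) mapping h to the space where the contrastive loss is applied."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.mlp = Sequential(Linear(hidden_dim, hidden_dim), ReLU(), Linear(hidden_dim, hidden_dim))

    def forward(self, h):
        return self.mlp(h)  # projected view z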
The four graph augmentations studied in this paper are:
  • Node Dropping: From any given graph $\mathcal{G}$, we randomly "drop" some nodes along with their edges.
  • Edge Perturbation: This involves perturbing the edges in $\mathcal{G}$ by randomly adding or dropping a certain ratio of edges.
  • Attribute Masking: Attribute masking prompts the model to recover masked node attributes using their context information, i.e., the remaining attributes.
  • Subgraph Generation: This involves creating a subgraph from the original graph by performing random walks.
The authors stress the importance of data augmentations and deem them crucial for graph contrastive learning.
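For illustration, here is a minimal sketch of two of these augmentations (node dropping and attribute masking) operating on a PyTorch Geometric Data object. The drop and mask ratios, and masking attributes with zeros, are illustrative assumptions rather than the authors' exact implementation.

import torch
from torch_geometric.data import Data
from torch_geometric.utils import subgraph


def drop_nodes(data: Data, drop_ratio: float = 0.2) -> Data:
    """Randomly drop a fraction of nodes along with their incident edges."""
    keep_mask = torch.rand(data.num_nodes) >= drop_ratio
    edge_index, _ = subgraph(keep_mask, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
    return Data(x=data.x[keep_mask], edge_index=edge_index)


def mask_attributes(data: Data, mask_ratio: float = 0.2) -> Data:
    """Randomly mask the attributes of a fraction of nodes (here, by zeroing them out)."""
    x = data.x.clone()
    mask = torch.rand(x.size(0)) < mask_ratio
    x[mask] = 0.0
    return Data(x=x, edge_index=data.edge_index)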


The overall framework can be generalised as follows:
$$l = \mathbb{E}_{\mathbb{P}_{\hat{\mathcal{G}_i}}} \left\{ -\, \mathbb{E}_{\mathbb{P}_{(\hat{\mathcal{G}_j} \mid \hat{\mathcal{G}_i})}} T\big(f_1(\hat{\mathcal{G}_i}), f_2(\hat{\mathcal{G}_j})\big) \, + \, \log\Big(\mathbb{E}_{\mathbb{P}_{\hat{\mathcal{G}_j}}} e^{T(f_1(\hat{\mathcal{G}_i}), f_2(\hat{\mathcal{G}_j}))}\Big) \right\}$$

where $T(\cdot)$ is an arbitrary learnable score function, usually parameterized with the similarity function $\text{sim}(\cdot, \cdot)$.
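In GraphCL, the score function is instantiated as the cosine similarity between the projected views, scaled by a temperature $\tau$, which yields the NT-Xent objective mentioned above. Below is a minimal batch-wise sketch of this loss; treating only the other view's graphs in the batch as negatives is a simplification of the full formulation.

import torch
import torch.nn.functional as F


def nt_xent_loss(z_i: torch.Tensor, z_j: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent loss between two batches of projected views of shape [batch_size, dim]."""
    z_i = F.normalize(z_i, dim=1)
    z_j = F.normalize(z_j, dim=1)
    # Pairwise cosine similarities between every view in z_i and every view in z_j
    sim = torch.mm(z_i, z_j.t()) / temperature  # [batch_size, batch_size]
    # Diagonal entries are the positive pairs; the other graphs in the batch act as negatives
    labels = torch.arange(z_i.size(0), device=z_i.device)
    loss_i = F.cross_entropy(sim, labels)      # anchor on the first view
    loss_j = F.cross_entropy(sim.t(), labels)  # anchor on the second view
    return (loss_i + loss_j) / 2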

Code

Let's look at the code in an abstract manner, implemented using PyTorch + PyTorch Geometric. The authors have made the original code available.
from typing import List

import torch


class GraphCL(torch.nn.Module):
    ...

    def train_step(
        self,
        augmented_views: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Perform a single training step.

        Args:
            augmented_views (List[torch.Tensor]): Views generated by performing augmentations.

        Returns:
            torch.Tensor: Contrastive loss for this step.
        """
        # Generate intermediate representations for each augmented view
        intermediate_reps = self.encode(augmented_views)

        # Project the representations into the space where the contrastive loss is applied
        reps = self.projection_head(intermediate_reps)

        # Calculate the contrastive (NT-Xent) loss between the projected views
        loss = self.contrastive_loss(reps)

        return loss
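To show how this abstract class might be driven, here is a hypothetical training loop. The names dataset, augment, and model are illustrative placeholders (they do not come from the authors' code), and the optimizer settings are arbitrary.

import torch
import wandb
from torch_geometric.loader import DataLoader

# Hypothetical setup: `dataset` is a PyTorch Geometric graph dataset, `augment`
# returns an augmented view of a batch, and `model` is a GraphCL instance as above.
loader = DataLoader(dataset, batch_size=128, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    for batch in loader:
        optimizer.zero_grad()
        # Two stochastic augmentations of the same batch form the positive pair
        views = [augment(batch), augment(batch)]
        loss = model.train_step(views)
        loss.backward()
        optimizer.step()
        wandb.log({"train/loss": loss.item()})  # track the loss with W&B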

Summary

In this article we briefly went over the paper titled "Graph Contrastive Learning with Augmentations" by Yuning You, Tianlong Chen, Yongduo Sui, Ting Chen, Zhangyang Wang, and Yang Shen, and contrasted (pun not intended) it with a related paper discussed earlier in A Brief Introduction to Graph Contrastive Learning. We also went over the implementation in an abstract manner and observed performance metrics with the power of Weights & Biases Logging.
To see the full suite of W&B features, please check out this short 5-minute guide. If you want more reports covering the math and "from-scratch" code implementations, let us know in the comments down below or on our forum ✨!
Check out these other reports on Fully Connected covering other Geometric Deep Learning topics such as Graph Attention Networks.