
Multi-Task Self Supervised Graph Representation Learning

A brief breakdown of "Multi-task Self-supervised Graph Neural Networks Enable Stronger Task Generalization" [ICLR 2023] by Mingxuan Ju, Tong Zhao, Qianlong Wen, Wenhao Yu, Neil Shah, Yanfang Ye and Chuxu Zhang.
NOTE: This report is part of a series on Graph Representation Learning. For a brief overview and survey, please refer to the following articles as well.



Over the past few articles we have introduced a basic framework for contrastive learning on graphs (GRACE), extended it to use a projection head (GraphCL), and adapted it to use smarter augmentations and graph diffusion (MVGRL), but we have still been limited to a single pretext task. Modern AI is now building its backbones on multi-modal, multi-task models that can learn across modalities and tasks. In this article we'll cover one framework that aims to enable multi-task self-supervised graph representation learning, as introduced in "Multi-task Self-supervised Graph Neural Networks Enable Stronger Task Generalization".
We assume a basic understanding of Graph Neural Networks. If you feel like you need a quick refresher, please refer to the following article, which provides links to other great resources to read and learn more!







👨‍🏫 Method

Figure 1: Proposed ParetoGNN Framework.
The prior methods we discussed in the context of self-supervised learning on graphs have mostly been adapted from earlier work in the self-supervised vision domain. These architectures show clear similarities to vision techniques such as SimCLR and VICReg, and are therefore constrained to a single pretext task built around a single philosophy, such as mutual information maximization in MVGRL. This works well in the obvious scenarios, i.e. downstream tasks that can naturally be formulated with that pretext task, but performance suffers on other downstream tasks. If we want to learn rich representations we can't just train another model for each downstream task; that won't scale. The LLM literature has taught us that a single model can be competitive on most downstream tasks, and the framework proposed in this paper tries to combine multiple philosophies to enhance task generalization for SSL-based GNNs.
Other attempts have been made to solve this problem, but they rely on a node-level pseudo-homophily assumption. Instead, the authors propose to train a single graph encoder in which all pretext tasks are simultaneously optimized and dynamically coordinated. They reconcile the pretext tasks by dynamically assigning weights that promote Pareto optimality, such that the graph encoder actively learns knowledge from every pretext task while minimizing conflicts.
The ParetoGNN framework can be summarised as follows:
  • Given a full graph $\mathcal{G}$ as the data source, for each pretext task we sample sub-graphs from $\mathcal{G}$ using task-specific augmentations $T_k(\cdot)$. Sub-graph sampling is a natural form of augmentation that increases the diversity of the training data and also keeps multi-task training memory efficient.
  • We then get task-specific node representations by passing each perturbed graph through one shared graph encoder parameterised by $\theta_g$.
  • For every pretext task we define $\mathcal{L}_k$ to be the loss for the $k$-th task, written as $\mathcal{L}_k(\mathcal{G}; T_k, \theta_g, \theta_k)$, where $\theta_k$ is some task-specific network that decodes the node representations, e.g. an MLP projection head or a graph decoder.
  • We then minimise the losses of all pretext tasks simultaneously using a Pareto optimisation formulation (essentially the multiple gradient descent algorithm); a minimal sketch of one such training step is shown right after this list.
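To make that loop concrete, here's a minimal sketch of what one ParetoGNN-style training step could look like. Note that the names are placeholders, not the authors' actual API: each `task` is assumed to bundle its sub-graph sampler plus augmentation $T_k$ and its pretext loss $\mathcal{L}_k$, and `compute_task_weights` stands in for the Pareto/MGDA reweighting discussed later.
import torch

# Hypothetical sketch of a single ParetoGNN-style training step.
def training_step(full_graph, encoder, tasks, compute_task_weights, optimizer):
    per_task_losses = []
    for task in tasks:
        # Task-specific view of the shared data source G (sub-graph sampling + T_k).
        subgraph, feats = task.sample_and_augment(full_graph)
        # One shared graph encoder (theta_g) for every pretext task.
        z = encoder(subgraph, feats)
        # Task-specific head/decoder (theta_k) lives inside `task.loss`.
        per_task_losses.append(task.loss(z, subgraph, feats))

    # Dynamically assigned weights alpha_k that promote Pareto optimality.
    alphas = compute_task_weights(per_task_losses, encoder)
    total_loss = sum(a * l for a, l in zip(alphas, per_task_losses))

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    return total_loss.item()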
But which pretext tasks should we use? The authors designed five simple pretext tasks spanning three high-level philosophies: generative reconstruction, whitening decorrelation, and mutual information maximization.
  • Generative Reconstruction involves masking the features of a random batch of nodes, processing the masked graph through the GNN encoder, and reconstructing the masked node features given the node representations of their local sub-graphs. A similar technique is applied to reconstruct the links between connected nodes to retain pair-wise topological knowledge. These tasks are referred to as FeatRec and TopoRec.
  • Whitening Decorrelation involves independently augmenting the same sub-graph into two views, then minimizing the distance between the same node in the two views while enforcing the feature-wise covariance of all nodes to equal the identity matrix. Referred to as RepDecor; a minimal sketch of this loss follows after the list.
  • Mutual Information Maximization involves maximizing the local-global mutual information by minimizing the distance between the graph-level representation of the intact sub-graph and its node representations, while maximizing the distance between the former and the corrupted node representations. The same is applied to views of the sub-graphs as well. Referred to as MI-NG and MI-NSG.
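As a concrete example of the whitening-decorrelation philosophy, here's a minimal sketch of a RepDecor-style loss, loosely following Barlow Twins/CCA-SSG-style objectives. The function name and the `lambd` trade-off are assumptions for illustration, not the paper's exact formulation.
import torch

# Hypothetical whitening-decorrelation loss: pull the same node's two views
# together while pushing the feature-wise covariance towards the identity.
def whitening_decorrelation_loss(z1, z2, lambd=1e-3):
    # z1, z2: (num_nodes, dim) node representations of two augmented views.
    n, d = z1.shape
    # Standardize each feature dimension (zero mean, unit variance).
    z1 = (z1 - z1.mean(0)) / (z1.std(0) + 1e-6)
    z2 = (z2 - z2.mean(0)) / (z2.std(0) + 1e-6)

    # Invariance term: the same node in both views should match.
    invariance = (z1 - z2).pow(2).sum() / n

    # Decorrelation term: feature-wise covariance should be close to identity.
    c1 = (z1.T @ z1) / n
    c2 = (z2.T @ z2) / n
    eye = torch.eye(d, device=z1.device)
    decorrelation = (c1 - eye).pow(2).sum() + (c2 - eye).pow(2).sum()

    return invariance + lambd * decorrelation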
Let's expand on this multiple-loss optimization problem. The loss minimisation formulation we are looking at can be formalised as:
$$\min_{\theta_g, \theta_1, \ldots, \theta_K} \; \sum_{k=1}^{K} \alpha_k \cdot \mathcal{L}_k(\mathcal{G}; T_k, \theta_g, \theta_k)$$

Now, I know that might look scary, but it's simply a weighted sum! $\alpha_k$ is the weight associated with task $k$ and $\mathcal{L}_k$ is the loss associated with that task. Something to note here is that this is a single-objective formulation, i.e. we are trying to optimise the weighted sum and not the individual losses. This leads to undesirable behaviours such as different gradient scales for each loss. Hence, the authors tackle this by formulating it as a multi-objective optimization problem, i.e. we try to optimise all the losses simultaneously.
To keep the article relatively light, we are not discussing the theorems behind the Pareto-optimal formulation of the multiple gradient descent algorithm and the troubles involved in adapting it to a multi-task self-supervised setting. However, if you want to learn more, please refer to the original paper on the Multiple Gradient Descent Algorithm and Section 2.3 of the paper.
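To give a flavour of what "dynamically assigning weights that promote Pareto optimality" means, here's a hedged illustration of the multiple gradient descent idea for the special case of two tasks, where the minimum-norm convex combination of the task gradients has a closed form (Sener & Koltun, 2018). The function name is hypothetical, and ParetoGNN handles the general K-task case rather than this two-task shortcut.
import torch

# Hypothetical two-task MGDA weights: find the convex combination of the two
# task gradients with minimum norm; its direction is a common descent direction.
def two_task_mgda_weights(grad1, grad2):
    # grad1, grad2: flattened gradients of the two losses w.r.t. the shared
    # encoder parameters (e.g. obtained via torch.autograd.grad + torch.cat).
    diff = grad1 - grad2
    denom = diff.pow(2).sum()
    if denom.item() == 0.0:
        # Identical gradients: any convex combination gives the same update.
        return 0.5, 0.5
    # alpha minimises || alpha * grad1 + (1 - alpha) * grad2 ||^2,
    # clipped to [0, 1] so the weights stay a convex combination.
    alpha = torch.clamp(((grad2 - grad1) * grad2).sum() / denom, 0.0, 1.0)
    return alpha.item(), 1.0 - alpha.item()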

👨‍💻 Code

The ParetoGNN implementation builds on top of the frameworks discussed in the previous articles of this series, with the key difference lying in the graph encoder, which implements the various pretext tasks and the augmentations used to generate their views.
Let's look at the implementation of one of these pretext tasks: maximising the node-graph mutual information between the intact and corrupted views (MI-NG).
import torch
import torch.nn.functional as F

class GraphEncoder(torch.nn.Module):
    ...

    def p_ming(self, graph, feat, cor_feat):
        # Node representations of the intact and the corrupted sub-graph.
        positive = self.big_model(graph, feat)
        negative = self.big_model(graph, cor_feat)

        # Graph-level summary vector of the intact sub-graph.
        summary = torch.sigmoid(positive.mean(dim=0))

        # Discriminator scores each (node, summary) pair.
        positive = self.discriminator(positive, summary)
        negative = self.discriminator(negative, summary)

        # Intact pairs are labelled 1, corrupted pairs 0.
        l1 = F.binary_cross_entropy(torch.sigmoid(positive), torch.ones_like(positive))
        l2 = F.binary_cross_entropy(torch.sigmoid(negative), torch.zeros_like(negative))
        return l1 + l2
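As a rough usage example, assuming `encoder` is an instantiated GraphEncoder and `subgraph`/`feat` come from a sampled sub-graph, the corrupted features are typically obtained by shuffling the feature rows (as in DGI):
perm = torch.randperm(feat.shape[0])   # shuffle node features to build negatives
cor_feat = feat[perm]
loss_ming = encoder.p_ming(subgraph, feat, cor_feat)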
For the full code, please refer to the official implementation. While multi-task training is tricky to implement, I am working on a simple abstracted implementation and plan to release a Colab soon, so stay tuned to this article!
I preprocessed and made the Heterophilous datasets available online (Actor, Squirrel, Chameleon).

🔗 Summary

In this article, we briefly went over a multi-task self-supervised graph representation learning framework as introduced in the paper "Multi-task Self-supervised Graph Neural Networks Enable Stronger Task Generalization" by Mingxuan Ju, Tong Zhao, Qianlong Wen, Wenhao Yu, Neil Shah, Yanfang Ye and Chuxu Zhang. While its objective is more involved than in prior work, the framework builds on that work in a minimal fashion and appears to be sturdy and scalable.
To see the full suite of W&B features, please check out this short 5-minute guide. If you want more reports covering the math and "from-scratch" code implementations, let us know in the comments down below or on our forum ✨!
Check out these other reports on Fully Connected covering other Geometric Deep Learning topics such as Graph Attention Networks.