
Citation Networks With PyTorch Geometric and Weights & Biases

In this article, we'll provide a primer on how to use GNNs and PyTorch Geometric to perform deep learning on linked data — specifically, on citation networks.

Introduction

Whether you're an experienced model-builder or just starting your model-building journey, you may have run into classes of problems that were challenging or too costly to represent using the standard set of deep learning models: convolutional neural networks (CNNs), recurrent neural networks (RNNs), long short-term memory networks (LSTMs), and so on.
When working with spatial or sequential data, those model types often work well. However, we often encounter machine learning tasks that require our model to develop an understanding of both local and global relationships.
A real-world example that requires modeling both local and global phenomena is how three entities, each purchasing a single home in a neighborhood, can drive up the price of real estate across an entire region.
Using three "comps" (short for comparables, i.e., the prices at which similar homes recently sold), the next set of homes for sale is priced based on those three previous purchase prices. Then, as those neighborhood-level home prices increase or decrease, the city-wide, county-wide, and regional prices increase or decrease with them.
Ultimately, it's possible to nudge an entire region's home prices in the United States up or down three home sales at a time. If we attempted to model that 'price cascade' at the neighborhood, city, county, state, and regional levels using traditional models, we would find it difficult to capture both the local and the global interactions.
However, by modeling our home-buying data as a series of nodes (homes) and edges (price influence), we can see how the local pricing of only a few homes in a neighborhood can cascade into very real pricing shifts at the state or even regional level. To model that relationship in a cost-effective way that captures both local and global phenomena, we use graph data structures. To perform machine learning on those graphs, which consist of nodes and the 'things linking nodes' that we call edges, we build graph neural networks (GNNs).

In this article — and in the accompanying Colaboratory Notebook — we will show you how to use GNNs and PyTorch Geometric to perform deep learning on another kind of linked data: a collection of machine learning papers linked by their co-occurring words. You'll learn how to use a GNN to predict the linkages between papers based on the tokens (words) that appear in each paper.

Background to Graph Neural Networks (GNNs)

As mentioned in the introduction, Graph Neural Networks (GNNs) are well-suited to graph-based tasks like node and graph classification, node and graph clustering (grouping similar kinds of nodes together), link prediction (which estimates the probability of links, called edges, between nodes in a graph), etc.
In its most basic form, link prediction predicts which nodes of a static graph are joined together; dynamic link prediction, on the other hand, allows us to model non-static, changing networks, which more closely approximate real-world phenomena such as community formation or the propagation of a message or a contagion through a web of linked entities.
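To make the static case concrete, here's a minimal, hedged sketch of link prediction scoring: given node embeddings (here random placeholders standing in for the output of a GNN encoder) and a set of observed edges, we score each candidate edge by the dot product of its endpoint embeddings and contrast observed edges with randomly sampled non-edges.

```python
import torch
from torch_geometric.utils import negative_sampling

num_nodes, emb_dim = 2708, 64
z = torch.randn(num_nodes, emb_dim)                      # placeholder embeddings (normally a GNN's output)
pos_edge_index = torch.randint(0, num_nodes, (2, 1000))  # placeholder "observed" edges, shape [2, E]

# Sample an equal number of node pairs that are *not* edges to serve as negatives.
neg_edge_index = negative_sampling(pos_edge_index, num_nodes=num_nodes,
                                   num_neg_samples=pos_edge_index.size(1))

def edge_score(z, edge_index):
    # Dot product between the embeddings of the two endpoints of every candidate edge.
    return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

pos_score = edge_score(z, pos_edge_index)  # should trend high after training
neg_score = edge_score(z, neg_edge_index)  # should trend low after training
```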
Graphs can also be divided into synchronous graphs, in which messages are delivered within one unit of time and the nodes have access to a common clock, and asynchronous graphs, in which messages can be delayed arbitrarily and no common clock is available to the nodes.
Diving more deeply into how these games of 'telephone' work, GNNs rely on the basic idea of propagating information between connected nodes in a graph, and they make use of techniques such as message passing and diffusion to do so. Message passing is a technique in which information is propagated between neighboring nodes in a graph.
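As a rough illustration of the message-passing idea, here is a minimal, hedged sketch of a toy layer built on PyTorch Geometric's MessagePassing base class; it simply sums each node's neighbors' linearly transformed features, and is meant to show the mechanics rather than reproduce any particular published layer.

```python
import torch
from torch_geometric.nn import MessagePassing

class SumNeighbors(MessagePassing):
    """Toy layer: each node's new feature is the sum of its neighbors' transformed features."""

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # aggregate incoming messages by summation
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x: [num_nodes, in_channels]; edge_index: [2, num_edges]
        return self.propagate(edge_index, x=self.lin(x))

    def message(self, x_j):
        # x_j holds, for every edge, the (transformed) features of the sending node.
        return x_j

# Tiny usage example on a three-node path graph: 0 - 1 - 2
x = torch.eye(3)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
layer = SumNeighbors(in_channels=3, out_channels=4)
print(layer(x, edge_index).shape)  # torch.Size([3, 4])
```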
Diffusion is a technique used to propagate information across more than the one-hop distance reached by a single round of message passing or graph convolution. Indeed, diffusion is a recent development in graph neural networks: replacing message passing with graph diffusion convolutions consistently leads to significant performance improvements across a wide range of models, on both supervised and unsupervised tasks, and on a variety of datasets.
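If you'd like to experiment with diffusion yourself, PyTorch Geometric ships a Graph Diffusion Convolution (GDC) transform. The following is a hedged sketch, assuming GDC's personalized-PageRank diffusion and top-k sparsification settings are reasonable for your task; the parameter values are illustrative rather than tuned.

```python
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

# Replace the raw adjacency structure with a sparsified diffusion matrix (PPR-based).
gdc = T.GDC(
    self_loop_weight=1,
    normalization_in='sym',
    normalization_out='col',
    diffusion_kwargs=dict(method='ppr', alpha=0.05),
    sparsification_kwargs=dict(method='topk', k=128, dim=0),
    exact=True,
)

dataset = Planetoid(root='data/Planetoid', name='Cora', transform=gdc)
data = dataset[0]  # data.edge_index and data.edge_attr now encode the diffused, sparsified graph
```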
GNNs can be used to classify networks, identify communities, detect anomalies, and predict future relationships between nodes. GNNs can also be used to predict the future state of a graph or the behavior of a graph over time. As stated in the introduction, this technique - called dynamic link prediction - can be used to forecast the spread of a virus or to predict the spread of an opinion or idea throughout a network of interconnected beings. Thinking more literally about the term network, we note that GNNs can also be used to identify suspicious or malicious activities and actors in computer networks.

The Cora Graph: Modeling Papers as Term-Frequency Matrices

The Cora dataset, sourced here from PyTorch Geometric, comes originally from the “Automating the Construction of Internet Portals with Machine Learning” paper from researchers at Carnegie Mellon and MIT. Cora is a paper citation network consisting of 2,708 scientific publications. Each node in the graph represents a publication, and a pair of nodes is connected by an edge if one paper cites the other.
Papers come from seven different domains, which make up the seven classes in the dataset: Theory, Reinforcement Learning, Genetic Algorithms, Neural Networks, Probabilistic Methods, Case-Based, and Rule Learning. The citation network consists of 5,429 links between the publications.
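You can pull Cora directly through PyTorch Geometric's Planetoid dataset wrapper. A minimal sketch (the root path is just an example) looks like this:

```python
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='data/Planetoid', name='Cora')
data = dataset[0]  # the single graph contained in the dataset

print(dataset.num_classes)    # 7 paper categories
print(data.x.shape)           # node features: [2708, 1433] binary word indicators
print(data.edge_index.shape)  # citation edges, stored in both directions: [2, num_edges]
```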

[Embedded W&B panel: Run set (11 runs)]

Each publication in the dataset is described by a 0/1-valued word vector indicating the absence or presence of the corresponding word from the dictionary. The dictionary consists of 1,433 unique words; it was derived by stemming the corpus, taking counts of tokens, and discarding words whose tallies fell under ten. While rudimentary, this process was standard for natural language processing (NLP) pre-processing workflows at the time, twenty years ago.
We can think of this co-occurrence matrix, shown below, as a way of illuminating the 'semantic proximity' of one paper to another: if our paper on lobsters shares a great deal of words with a paper on starfish but far fewer with a paper on sculptures, we can assume that the lobster and starfish papers are more closely related than the lobster and sculpture papers.
Below, we've taken the Cora dataset's term-frequency matrix - a tally of which words occur in which paper - and rendered it using the Table functionality in the wandb library.
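As a hedged sketch of that logging step (the project name and column labels are our own choices, not necessarily those used in the accompanying Colab), you could log a slice of the matrix like this:

```python
import wandb
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='data/Planetoid', name='Cora')
data = dataset[0]

run = wandb.init(project='cora-gnn', job_type='log-term-frequency')  # illustrative project name

# One row per paper, one 0/1 column per vocabulary word.
columns = ['paper_id'] + [f'w_{i}' for i in range(data.x.size(1))]
table = wandb.Table(columns=columns)
for paper_id in range(20):  # log a small slice; the full matrix has 2,708 rows
    table.add_data(paper_id, *data.x[paper_id].int().tolist())

run.log({'term_frequency_matrix': table})
run.finish()
```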

Co-Occurrence Matrices: Modeling Tokens (Words) in a Paper

Given a set of documents, D, and a set of tokens, T, let w_ij denote the presence or absence of token T_i in document D_j:
\begin{bmatrix} & T_1 & T_2 & \cdots & T_t \\ D_1 & w_{11} & w_{21} & \cdots & w_{t1} \\ D_2 & w_{12} & w_{22} & \cdots & w_{t2} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ D_n & w_{1n} & w_{2n} & \cdots & w_{tn} \end{bmatrix}

The Wikipedia page on document-term matrices is an excellent refresher on the topic, on why we use that representation of 'word tallies in a document,' and more. Below is a 'poor man's document-term matrix', which we've constrained to the top twenty observations - the twenty papers with the most non-zero word values - instead of the 2,708 observations in the original dataset.
This data is quite sparse, so by curating the top twenty records, the document-term matrix shown has at least a few non-zero values in every row. Recall that the columns run from w_0 through w_1432, since the Cora dataset's creators whittled the vocabulary of all 2,708 papers down to 1,433 unique tokens (words):
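If you'd like to reproduce that curation, here is a small hedged sketch using torch.topk and the data object from the Planetoid snippet above; the choice of twenty rows is simply the cutoff we used for readability.

```python
import torch

# Count how many of the 1,433 vocabulary words each paper actually contains ...
nonzero_counts = (data.x != 0).sum(dim=1)

# ... and keep the twenty papers with the most non-zero word indicators.
top_counts, top_indices = torch.topk(nonzero_counts, k=20)
dense_slice = data.x[top_indices]
print(dense_slice.shape)  # torch.Size([20, 1433])
```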

[Embedded W&B table from run lemon-elevator-1]

If we were approaching this same task today - attempting to understand citation networks and the shared semantic attributes of papers, such as co-occurring words - we would probably move beyond the simple term-frequency matrix shown above.
As mentioned earlier, the term-frequency matrix for the Cora dataset was generated by tallying up the occurrences of each word in a paper and discarding any word mentioned fewer than ten times. For more robust, less synthetic datasets, we encourage you to look at the Open Graph Benchmark datasets from Stanford University.
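As one concrete starting point, the ogb Python package exposes PyTorch Geometric-compatible loaders. A minimal, hedged sketch for the ogbn-arxiv citation graph (assuming the ogb package is installed; the root path is illustrative) looks like this:

```python
from ogb.nodeproppred import PygNodePropPredDataset

# ogbn-arxiv: a larger, more recent citation network than Cora
dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='data/OGB')
split_idx = dataset.get_idx_split()  # standardized train/valid/test node splits
data = dataset[0]
print(data)
```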

Making Predictions using the Cora Citation Network

In the chart below, you can see the loss curve from our training experiment (which you can find in this Colab Notebook). Note that the spikes are from resumed or restarted training processes. Weights & Biases makes it easy to resume crashed or otherwise interrupted training runs, so you don't have to throw away precious minutes or hours of GPU or TPU compute time.
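The Colab contains the full training code. As a hedged, simplified sketch of the kind of loop behind a curve like this (the hyperparameters and project name are illustrative, not necessarily those used in the notebook), a two-layer GCN trained for node classification on Cora with wandb logging might look like this:

```python
import torch
import torch.nn.functional as F
import wandb
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.conv2(x, edge_index)

dataset = Planetoid(root='data/Planetoid', name='Cora')
data = dataset[0]

model = GCN(dataset.num_features, 16, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# resume='allow' lets an interrupted run pick up where it left off when the same run id is reused.
run = wandb.init(project='cora-gnn', resume='allow')

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    run.log({'epoch': epoch, 'train_loss': loss.item()})

run.finish()
```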


Moving on to these charts, you'll see the evolution of the embeddings over time as the GCN is trained. On the right-hand side is the final loss curve for our experiment. We experimented with assorted hyperparameters – varied learning rates, etc. – which resulted in slightly more or less performant models.

[Embedded W&B panel: Run set (10 runs)]


Conclusion

With an understanding of how to use GNNs and PyTorch Geometric, we encourage you to experiment with various architectures, hyperparameters, and datasets while working on authorship and citation graph tasks like those posed by the Cora dataset.
If you're new to the domain of graph theory and/or graph neural networks, we've curated a short list of conferences, papers, textbooks, and tutorials to help get you started:
Happy reading!