19. Applications of Graph Neural Networks
PinSAGE
PinSAGE is a large-scale recommender system built on GNNs. The problem setting: users interact with items (watch movies, buy merchandise, listen to music, etc.), and the goal is to recommend items they might like.
A set of users interacts with a set of items. Based on the interactions between these two sets, our goal is to recommend users new items that they might like but are not yet aware of.
For a given query Q, return a set of similar items to recommend to the user. We formulate the query Q from user interactions; it can be a single item or a set of items, and our goal is to return the items most similar to the query as a whole.
Having a universal similarity function over items enables many applications: Homefeed (an endless feed of recommendations), Related Products on a product page, Ads, and so on.
There are two ways of defining similarity: content-based and graph-based. Content-based similarity considers only the features of users and items, such as images, text, and categories. Graph-based similarity is based on user-item interactions, viewed as a bipartite graph; this is also called collaborative filtering: for a given user x, find other users who liked similar items, and estimate what x will like based on what those similar users like. Essentially, the features of a user are the items they like, and the features of an item are the users who like it.
For this we need to gather known similarities between items and extrapolate the unknown similarities from the known ones. We are mostly interested in knowing what users like, not what they dislike. We also need a way of evaluating our recommendation methods.
Pinterest's use case: 300M users, more than 4B pins, belonging to more than 2B boards. Boards are collections of pins that users curate. We can view the whole thing as one big bipartite graph between pins and boards.
They have two sources of signal: the features of the pin itself, like its text and image, and the graph structure, like which boards a given pin belongs to. The graph is dynamic, so whatever model we learn needs to apply to new nodes without retraining.
Pin Embeddings
For the Related Pins query: given a pin the user interacts with, which other pins should they be recommended? We can answer this by computing the pin's embedding and finding its nearest neighbors in the embedding space.
For this we need to learn embeddings for billions of pins and perform nearest-neighbor queries in real time. This is a challenge: if the embedding space has around 300 dimensions and there are billions of pins, we cannot compute distances to all of them for every query and sort. Instead we use locality-sensitive hashing, which finds approximate nearest neighbors without comparing against every pin.
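A minimal sketch of the idea behind locality-sensitive hashing with random hyperplanes; the bucketing scheme and all names here are illustrative, not Pinterest's production system (which in practice also uses multiple hash tables to improve recall):

```python
import numpy as np

rng = np.random.default_rng(0)
dim, n_planes = 300, 8
planes = rng.standard_normal((n_planes, dim))     # random hyperplanes

def hash_codes(X):
    """Sign of the projection onto each hyperplane -> a short binary code."""
    return (X @ planes.T > 0).astype(np.uint8)

embeddings = rng.standard_normal((10_000, dim))   # stand-in for pin embeddings

# Bucket pins by hash code; nearby pins tend to share a bucket.
buckets = {}
for i, code in enumerate(hash_codes(embeddings)):
    buckets.setdefault(code.tobytes(), []).append(i)

def query(q, k=10):
    """Search only the query's bucket instead of all pins."""
    candidates = buckets.get(hash_codes(q[None])[0].tobytes(), [])
    cand = embeddings[candidates]
    sims = cand @ q / (np.linalg.norm(cand, axis=1) * np.linalg.norm(q))
    return [candidates[j] for j in np.argsort(-sims)[:k]]

print(query(embeddings[42]))  # pin 42 itself should rank first
```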
GNN
We also use the graph structure to get the pin embeddings: for every target pin we unroll the bipartite graph around it and run a GNN over this neighborhood.
Because we use both the graph structure and the pin features, we can borrow information from nearby nodes. Using the pin features alone, an image recognizer might mistake a garden fence for a bed railing; in the graph, however, beds and gardens are far apart.
These pin embeddings are essential for many downstream tasks like recommending related ads, home feed recommendation, cluster users by their interests etc.
PinSAGE Pipeline
Data collection
We collect billions of training pairs from logs. Positive pairs are two pins saved consecutively into the same board within a 1-hour interval; negative pairs are two pins sampled at random, which with high probability are not on the same board.
Train GNN
We train a GNN to generate embeddings for the pins, with the objective that positive pairs end up close together and negative pairs far apart in the embedding space.
Inference
Generate embeddings for all pins, then use these embeddings for whatever downstream task there is.
Objective Function
Max-margin Loss:

$$\mathcal{L} = \sum_{(u, v) \in \mathcal{D}} \max(0, -z_u^T z_v + z_u^T z_n + \Delta)$$
where $(u, v)$ is a positive training pair, $n$ is a negative sample for $u$, $z_i$ is the embedding of node $i$, and $\Delta$ is the margin deciding how much larger the positive-pair similarity should be than the negative-sample similarity. Similarity between two embedding vectors is a simple dot product (cosine similarity once the embeddings are normalized). We want embeddings that minimize $\mathcal{L}$.
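As a sketch, the loss in PyTorch over a mini-batch (taking the batch mean instead of the full sum, and the margin value of 0.4, are illustrative choices):

```python
import torch
import torch.nn.functional as F

def max_margin_loss(z_u, z_v, z_n, margin=0.4):
    """Hinge loss: positive-pair similarity should beat the negative by a margin.

    z_u, z_v: embeddings of the positive pairs, shape (batch, d)
    z_n: embeddings of a negative sample for each u, shape (batch, d)
    """
    # Cosine similarity = dot product of L2-normalized embeddings.
    z_u, z_v, z_n = (F.normalize(z, dim=1) for z in (z_u, z_v, z_n))
    pos = (z_u * z_v).sum(dim=1)    # z_u^T z_v
    neg = (z_u * z_n).sum(dim=1)    # z_u^T z_n
    return torch.clamp(neg - pos + margin, min=0).mean()

loss = max_margin_loss(torch.randn(32, 64), torch.randn(32, 64), torch.randn(32, 64))
```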
Key Innovations in PinSAGE
On-the-fly Convolutions
For every pin they perform on-the-fly graph convolutions: sample the neighborhood around the node and dynamically construct a computation graph from it. They never work with the entire graph at once; instead, each pin gets its own on-the-fly computation graph.
Using Random Walks
They also select the neighbors of a pin via random walks. If the pin has high degree, aggregating over all of its neighbors is infeasible, so we want to intelligently select the neighbors that will contribute most to the aggregation. For this they use personalized PageRank: define importance-based neighborhoods by simulating random walks and selecting the nodes with the highest visit counts. This is called importance pooling: choose the nodes with the top K visit counts and normalize those counts.
In GraphSAGE we do mean pooling, simply averaging the messages from the direct neighbors. In PinSAGE we instead take a weighted mean of the messages from the top K nodes, using the normalized visit counts as weights. In practice PinSAGE uses K = 50; there is negligible performance gain for K > 50.
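A rough sketch of importance pooling (the walk length of 2 assumes the bipartite pin-board graph, where a 2-step walk lands on another pin; the graph and message containers are simplified to plain dictionaries, and all names are illustrative):

```python
import random
from collections import Counter
import numpy as np

def importance_pooling(graph, node, messages, n_walks=200, walk_len=2, k=50):
    """Rank nodes by random-walk visit counts, keep the top k, and take a
    weighted mean of their messages using the normalized counts as weights.

    graph: dict node -> list of neighbors; messages: dict node -> np.ndarray
    """
    counts = Counter()
    for _ in range(n_walks):
        cur = node
        for _ in range(walk_len):          # short walk from the target pin
            cur = random.choice(graph[cur])
        if cur != node:
            counts[cur] += 1               # count where each walk lands
    top = counts.most_common(k)            # importance-based neighborhood
    total = sum(c for _, c in top)
    return sum((c / total) * messages[nbr] for nbr, c in top)
```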
MapReduce inference
Because we are using a GNN and must build a computation graph for each node, there is a lot of redundant computation: similar structures repeat across different pins, so localized graph convolutions at inference time recompute the same intermediate embeddings many times. PinSAGE exploits this with MapReduce, computing each node's representation at each layer exactly once and reusing it.
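A minimal sketch of the layer-wise idea, with plain Python dictionaries standing in for the map and reduce stages; the mean-then-ReLU aggregation is a simplification of PinSAGE's actual convolution:

```python
import numpy as np

def layerwise_inference(graph, features, weights):
    """Compute every node's embedding once per layer instead of once per
    occurrence in some pin's unrolled computation graph.

    graph: dict node -> list of neighbors; features: dict node -> np.ndarray
    weights: one (d_in, d_out) matrix per GNN layer
    """
    h = dict(features)
    for W in weights:
        # "Map": every node emits its current embedding to its neighbors.
        # "Reduce": every node averages what it received and applies the layer.
        h = {v: np.maximum(0.0, np.mean([h[u] for u in graph[v]], axis=0) @ W)
             for v in graph}
    return h
```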
Obtaining Harder negative samples
When the goal is to identify a target pin among 3B pins, the model has to make very fine-grained distinctions: it must learn at a resolution of roughly 100 out of 3 billion. When the task is to identify the closest 100 pins, randomly collected negative samples are too easy, so instead we use harder and harder negative samples. This forces the model to learn subtle distinctions between pins, and hard negative samples improve performance.
In PinSAGE, hard negatives also come from random walks: nodes whose visit counts rank 1000-5000 with respect to the query are used as hard negatives. These are negative samples that have something in common with the query, but are not too similar to it.
They start the training with random negative examples, and provide harder negative examples over time. This is called Curriculum training.
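A sketch of one possible schedule; the exact mixing rule below, and the function and argument names, are illustrative choices rather than PinSAGE's actual implementation:

```python
import random

def sample_negatives(visit_ranked, epoch, n=500):
    """Curriculum negatives: all random at epoch 0, more hard ones over time.

    visit_ranked: pins sorted by random-walk visit count w.r.t. the query
    """
    hard_pool = visit_ranked[1000:5000]   # related to the query, but not too similar
    n_hard = min(epoch, n)                # ramp up hard negatives per epoch
    hard = random.sample(hard_pool, n_hard)
    easy = random.sample(visit_ranked[5000:], n - n_hard)
    return hard + easy
```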
PinSAGE experiments
In related-pin recommendation there are two metrics: Hit-rate and Mean Reciprocal Rank (MRR).
Hit-rate: the fraction of queries for which the positive example X appears among the top K items closest to the query Q.
Mean Reciprocal Rank: take the reciprocal of the rank of X among the recommendations for query Q, and average this over all queries (the paper scales the rank down before taking the reciprocal, since ranks over billions of pins are huge).
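A small sketch of both metrics in their standard form, assuming each query's recommendations come as a ranked list:

```python
def hit_rate(ranked_lists, positives, k=10):
    """Fraction of queries whose positive item appears in the top k."""
    hits = sum(pos in ranked[:k] for ranked, pos in zip(ranked_lists, positives))
    return hits / len(positives)

def mean_reciprocal_rank(ranked_lists, positives):
    """Average of 1 / rank of the positive item (ranks are 1-based)."""
    return sum(1.0 / (ranked.index(pos) + 1)
               for ranked, pos in zip(ranked_lists, positives)) / len(positives)

ranked = [["a", "b", "c"], ["c", "a", "b"]]
print(hit_rate(ranked, ["b", "b"], k=2))          # 0.5: hit only in the first list
print(mean_reciprocal_rank(ranked, ["b", "b"]))   # (1/2 + 1/3) / 2
```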
With 3 billion pins, these are the results; the baseline embeddings are visual only (VGG), annotation only (Word2Vec), and the two combined.
| Method | Hit-rate | MRR |
| --- | --- | --- |
| Visual | 17% | 0.23 |
| Annotation | 14% | 0.19 |
| Combined | 27% | 0.37 |
| max-pooling | 39% | 0.37 |
| mean-pooling | 41% | 0.51 |
| mean-pooling-xent | 29% | 0.35 |
| mean-pooling-hard | 46% | 0.56 |
| PinSAGE | 67% | 0.59 |
Decagon
Decagon shows how to use GNNs on heterogeneous networks, where we have multiple types of nodes and edges, to predict (and thus help prevent) side effects of drug combinations.
Many medications have side effects that are known by the time the drug reaches the market. But many patients take multiple drugs at once, and we want to know what side effects the combination might cause, given the drugs' individual side effects, or even when they have no individual side effects. These are called polypharmacy side effects. They are much harder to learn from data and drug trials: the relevant data points can be very rare, and many combinations are never observed in clinical testing, so identifying them manually is difficult. The question matters: half of people over 70 take 5 different drugs, and nearly 15% of the US population takes multiple medications for complex diseases and coexisting conditions. Systematic experimental screening of drug interactions is challenging.
So instead, use molecular, pharmacological, and patient population data to computationally screen for polypharmacy side effects. The predictions can then guide translational strategies for combination treatments in patients.
For now, we model a system that predicts how likely a side effect r is for a given pair of drugs c and d.
There are two types of nodes: drug nodes and protein nodes. Proteins are molecules in our bodies that interact with each other and with drugs to carry out biological processes. Biologists have a protein-protein interaction network of about 20k nodes specifying how proteins interact with each other, and every protein node has a feature vector of its own. Every drug also has a feature vector describing its properties, along with edges to the proteins it interacts with (drug-protein interactions). Finally, drug-drug edges represent the side effects r that we need to identify.
Our goal: given the partially observed graph, predict labeled edges between drug nodes that represent side effects. Query: given a drug pair (c, d), how likely is an edge (c, r, d), meaning that taking drugs c and d together causes the polypharmacy side effect r?
For every drug node, the GNN computes messages from each edge type separately, then aggregates them across the different edge types.
The node embedding for a drug node c is computed with one layer of this GNN; this is the drug node encoder.
Then, using the drug node embeddings of the query pair, a decoder network predicts the possible edge type, i.e. the side effect: edge prediction from node embeddings.
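A minimal sketch of the encoder/decoder pair in PyTorch. The per-edge-type weights and mean aggregation follow the description above; the bilinear decoder is a simplification of Decagon's tensor-factorization decoder, and all class and argument names are illustrative:

```python
import torch
import torch.nn as nn

class DrugEncoderLayer(nn.Module):
    """One heterogeneous GNN layer: a separate weight matrix per edge type,
    messages aggregated within each type, then summed across types."""
    def __init__(self, edge_types, d_in, d_out):
        super().__init__()
        self.w = nn.ModuleDict({t: nn.Linear(d_in, d_out, bias=False)
                                for t in edge_types})
        self.w_self = nn.Linear(d_in, d_out, bias=False)

    def forward(self, h_self, neighbors_by_type):
        # neighbors_by_type: edge type -> (n_neighbors, d_in) feature tensor
        out = self.w_self(h_self)
        for t, h_nbrs in neighbors_by_type.items():
            out = out + self.w[t](h_nbrs.mean(dim=0))  # aggregate per edge type
        return torch.relu(out)

class EdgeDecoder(nn.Module):
    """Per-side-effect bilinear score: p(c, r, d) = sigmoid(z_c^T W_r z_d)."""
    def __init__(self, side_effects, d):
        super().__init__()
        self.w = nn.ParameterDict({r: nn.Parameter(0.01 * torch.randn(d, d))
                                   for r in side_effects})

    def forward(self, z_c, z_d, r):
        return torch.sigmoid(z_c @ self.w[r] @ z_d)
```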
- Input: a graph over molecules (protein-protein interactions and drug-target relationships) and a graph over populations (side effects of individual drugs, plus polypharmacy side effects of drug combinations, which form the edges between drugs).
- Setup: construct a heterogeneous graph of all the data. Train a model to predict known drug-drug association pairs and their polypharmacy side effects. Then, given a drug query pair, infer candidate polypharmacy side effects.
Results:
| Model | AUROC | AUPRC | AP@50 |
| --- | --- | --- | --- |
| Decagon 3-layer | 0.834 | 0.776 | 0.731 |
| Decagon 2-layer | 0.809 | 0.762 | 0.713 |
| RESCAL | 0.693 | 0.613 | 0.476 |
| Node2Vec | 0.725 | 0.708 | 0.643 |
| Drug Features | 0.736 | 0.722 | 0.679 |
Decagon is the heterogeneous GNN they used; it achieved a 54% improvement over the baselines. They then used the improved model to predict novel drug interactions that were not previously known, and 5 of the top 10 predictions were later confirmed by medical research.
Goal Directed Graph Generation (GCPN)
An extension of GraphRNN with the additional goal of generating graphs that have a certain property: for example, generating new, valid molecules with a high value of a given chemical property, learned from realistic examples.
Chapters
- Introduction, Structure of Graphs
- Properties of Networks and Random Graph Models
- Motifs and Structural Roles in Networks
- Community Structure in Networks
- Spectral Clustering
- Message Passing and Node Classification
- Graph Representation Learning
- Graph Neural Networks
- Graph Neural Networks - Pytorch Geometric
- Deep Generative Models for Graphs
- Link Analysis: PageRank
- Network Effects and Cascading Behaviour
- Probabilistic Contagion and Models of Influence
- Influence Maximization in Networks
- Outbreak Detection in Networks
- Network Evolution
- Reasoning over Knowledge Graphs
- Limitations of Graph Neural Networks
- Applications of Graph Neural Networks