
Recommending Amazon Products using Graph Neural Networks in PyTorch Geometric

In this report, we will walk you through how you can leverage PyTorch Geometric along with Weights & Biases to analyze the Amazon products graph and recommend products from it.

Introduction

Shopping online has certain advantages over brick-and-mortar browsing. While you sacrifice being able to touch or try on items, browsing a catalog is far easier, and online retailers excel at this: not just browsing, but recommendations based on your searches and your behavior within their catalog. The question is: how exactly do they surface these recommendations?
In this report, we'll show you how. Specifically, we'll walk you through how you can use a database of products on Amazon (along with some additional information) and formulate and visualize the products as a graph in PyTorch Geometric along with Weights & Biases. We'll then use this graph to find products similar to a given product by trying to solve the fundamental Link Prediction problem in Graph Deep Learning using PyTorch Geometric.
Let's get going:

Formulating the Problem with Graphs

Building the Graph

There are, of course, several techniques to quantify product similarity and make recommendations, but a particularly interesting one is to build a graph.
One way of doing this is to treat each product as a node and add an edge between any pair of products that are frequently bought together. This gives a homogeneous graph, i.e. a graph in which every node represents one type of entity and every edge represents one type of relationship.
You can then make this more and more complex by adding more information. For example, let's add users as nodes to the graph and add edges between users and the products they have bought. This would make our graph heterogeneous, i.e. a graph which can have nodes of multiple types and, similarly, edges of multiple types representing a variety of relationships between the nodes.
If you want to read more about how Data Scientists at Amazon formulated this and scaled it, I would strongly recommend reading this blog post.
💡

What does the data look like?

We obtained a graph of products listed on Amazon, in which two products are linked by an edge if they were frequently bought together. This is openly available thanks to the amazing work of the Stanford Network Analysis Project (SNAP). A couple of details about how the data was curated:
The network was collected by crawling the Amazon website. It is based on the Customers Who Bought This Item Also Bought feature of the Amazon website. If a product i is frequently co-purchased with product j, the graph contains a directed edge from i to j. The data was collected on March 2, 2003.
The graph has 262,111 nodes and 1,234,877 edges. The format of the data looks like this →
# Directed graph (each unordered pair of nodes is saved once): Amazon0302.txt
# Amazon product co-purchaisng network from March 02 2003
# Nodes: 262111 Edges: 1234877
# FromNodeId ToNodeId
0 1
0 2
0 3
0 4
0 5
1 0

Loading the data into a PyTorch Geometric Graph

Now we need to convert this into a format which can be processed easily by PyG. We'll walk you through it below:
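As a rough sketch of what this conversion can look like (assuming the raw Amazon0302.txt edge list has been downloaded locally, and using each node's in-degree as a 1-dimensional feature, the baseline feature discussed later in this report), you could do something like:

import torch
from torch_geometric.data import Data
from torch_geometric.utils import degree

edge_list_path = "Amazon0302.txt"  # hypothetical local path to the SNAP edge list

# Read the "FromNodeId ToNodeId" pairs, skipping the comment lines
edges = []
with open(edge_list_path) as f:
    for line in f:
        if line.startswith("#"):
            continue
        src, dst = line.split()
        edges.append([int(src), int(dst)])

# PyG expects edge_index with shape [2, num_edges]
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
num_nodes = int(edge_index.max()) + 1

# Use each node's in-degree as a 1-dimensional node feature
in_degree = degree(edge_index[1], num_nodes=num_nodes).unsqueeze(1)

graph = Data(x=in_degree, edge_index=edge_index)
torch.save(graph, "amazon0302.pt")

The saved graph can then be logged as a W&B artifact, which is how the later snippets in this report retrieve it.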



Visualizing the PyTorch Geometric Graph

It is incredibly hard and resource-intensive to visualize hundreds of thousands of nodes, so we sampled the first 100 nodes of the graph using the subgraph utility from PyTorch Geometric. The code for that is available below:
import numpy as np
import torch
from torch_geometric import utils
from torch_geometric.data import Data

import wandb

# Download the entire Graph saved in a format suitable for PyG
wandb.init(project='gnn-recommender', job_type='preprocessing')
file_path = wandb.use_artifact("manan-goel/gnn-recommender/amazon_product_graph:latest").download()

graph = torch.load(f'{file_path}/amazon0302.pt')

# Create a mask with the value True for nodes to be retained and False for nodes to be removed
mask = np.zeros(graph.x.shape[0])
mask[:100] = 1
mask = torch.tensor(mask == 1)

# Create and save the new smaller graph by sampling nodes according to the mask
g = Data(x=graph.x[mask], edge_index=utils.subgraph(mask, graph.edge_index)[0])
torch.save(g, 'smaller_graph.pt')

# Save the new graph as a W&B artifact
smaller_graph = wandb.Artifact('smaller_graph', type='graph')
smaller_graph.add_file('smaller_graph.pt')
wandb.log_artifact(smaller_graph)

wandb.finish()
Now that we have a smaller graph that is easier to visualize, we can use PyVis to do so. We use the node metadata available as part of the dataset to create a more informative visualization; you can see how we parsed the metadata in the Appendix.
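A minimal sketch of the idea, assuming a titles dictionary mapping node ids to product titles (parsed from the metadata as shown in the Appendix), might look like this:

import torch
from pyvis.network import Network

# Load the 100-node subgraph saved in the snippet above
graph = torch.load("smaller_graph.pt")

# titles is assumed to map node id -> product title (see the Appendix for parsing);
# a placeholder is used here so the sketch runs on its own
titles = {i: f"Product {i}" for i in range(graph.num_nodes)}

net = Network(height="600px", width="100%", directed=True)
for node_id in range(graph.num_nodes):
    net.add_node(node_id, label=titles.get(node_id, str(node_id)))

# edge_index has shape [2, num_edges]; add one PyVis edge per column
for src, dst in graph.edge_index.t().tolist():
    net.add_edge(src, dst)

net.save_graph("amazon_subgraph.html")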

[W&B panel: interactive PyVis visualization of the 100-node product subgraph]


Creating Node Features

For every node in the graph, we need to provide the model with some information about what the node represents. For example, just as every pixel in an image carries three values for the RGB channels, every node must carry some features. The most basic way of doing this (shown in the previous section) is to use a node's in-degree as its feature. However, the in-degree says nothing about what the product actually is, so it may (as we will see in later sections) lead to poor performance.
It is important to find ways to "featurize" the nodes as robustly as possible. For example, another method we tried is to create Doc2Vec embeddings of each node's product title and use those as the input node features. Another option: the metadata also contains the list of categories each product belongs to, so an encoding of these categories can be a good starting point too. A rough sketch of the Doc2Vec approach is shown below.
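As a rough illustration of the Doc2Vec idea (not the exact code used for the runs in this report), the product titles could be embedded with gensim along these lines:

import numpy as np
import torch
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

# titles[i] is assumed to be the product title for node i, parsed from the metadata;
# placeholder titles are used here so the sketch is self-contained
titles = ["hypothetical product title"] * 100

documents = [TaggedDocument(words=title.lower().split(), tags=[i]) for i, title in enumerate(titles)]
doc2vec = Doc2Vec(documents, vector_size=64, window=3, min_count=1, epochs=20)

# One 64-dimensional embedding per node, which can be used as graph.x
node_features = torch.from_numpy(np.stack([doc2vec.dv[i] for i in range(len(titles))]))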
One of the fundamental problems that the geometric deep learning community studies is link prediction, i.e. predicting whether an edge exists between two nodes in a graph. This could mean predicting whether two users might become friends in a social network, or predicting interactions between genes and proteins in a biological network.
The way we are going to frame this problem in the report is: determine whether two products are similar, and hence whether one can be recommended as a suggestion when someone is looking at the other.
How can we measure the similarity between two nodes? One possible way is to use graph embeddings. Graph embedding algorithms learn an embedding space in which neighboring nodes are represented by vectors such that vector similarity measures, like dot-product similarity or Euclidean distance, hold in the embedding space. These embeddings can be learned using graph convolutional neural networks (GCNs). You can learn more about GCNs here and here.


Splitting the Dataset

The first thing we need to do is create train, validation, and test splits of the edges in the dataset. We start by creating a smaller graph with 20,000 nodes using the same script shown in the previous sections. You can then use the following script to randomly split the edges into three sets, with 5,000 edges each in the validation and test sets.
import torch
from torch_geometric.transforms import RandomLinkSplit
import wandb

# Download and load the graph from W&B artifacts
wandb.init(project="gnn-recommender", job_type="preprocessing", save_code=True)
file_path = wandb.use_artifact("manan-goel/gnn-recommender/smaller_graph:latest").download()
graph = torch.load(f'{file_path}/smaller_graph.pt')

# Add 5000 edges in the validation and test sets respectively
transform = RandomLinkSplit(num_val=5000, num_test=5000, is_undirected=True, split_labels=True)
train_data, val_data, test_data = transform(graph)

# Save the splits and save as W&B artifacts
torch.save(train_data, 'train.pt')
torch.save(val_data, 'val.pt')
torch.save(test_data, 'test.pt')

artifact = wandb.Artifact('split_smaller_graph', type='graph')
artifact.add_file('train.pt')
artifact.add_file('val.pt')
artifact.add_file('test.pt')

wandb.log_artifact(artifact)
wandb.finish()

Implementing the Model

As mentioned in the previous section, this model consists of two parts: a graph convolutional network that produces an embedding for each node, followed by a predictor that takes the embeddings of two nodes and predicts whether a link exists between them.
This implementation is inspired by this amazing blog.
💡

Graph Convolution

There are multiple kinds of graph convolution layers available in PyTorch Geometric. For this report, we will use the GraphSAGE model.
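A minimal sketch of such a stack, with the same constructor signature as the training snippet further below (the exact GNNStack used for the runs may differ), could look like this:

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GNNStack(torch.nn.Module):
    # A stack of GraphSAGE convolutions with ReLU and dropout between layers
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout, emb=False):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(input_dim, hidden_dim))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_dim, hidden_dim))
        self.convs.append(SAGEConv(hidden_dim, output_dim))
        self.dropout = dropout
        self.emb = emb  # if True, return raw node embeddings instead of log-probabilities

    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return x if self.emb else F.log_softmax(x, dim=1)

The emb flag is what the training snippet sets to True, since we want the node embeddings themselves rather than class probabilities.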


For a pair of nodes, the previous module provides an embedding for each of them. The next module is responsible for combining the two embeddings and making a binary prediction.
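A minimal sketch of such a predictor, an MLP applied to the element-wise product of the two embeddings, with the constructor signature used in the training snippet below (again, the exact module used for the runs may differ):

import torch
import torch.nn.functional as F

class LinkPredictor(torch.nn.Module):
    # Combines two node embeddings and outputs the probability that an edge exists
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super().__init__()
        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))
        self.dropout = dropout

    def forward(self, x_i, x_j):
        x = x_i * x_j  # element-wise product of the two node embeddings
        for lin in self.lins[:-1]:
            x = F.relu(lin(x))
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x)

The element-wise product keeps the predictor symmetric in the two nodes, which suits an undirected notion of "bought together".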



Training the Model

Training a link prediction model brings up a very interesting problem: the dataset we possess is a list of edges that exist in the graph, so if we treat this as a binary classification problem, we only have positive samples. Hence we also need "negative edges", i.e. edges that do not actually exist in the graph, which we treat as negative samples. PyTorch Geometric provides a utility for this as well.
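A rough sketch of a training loop using PyTorch Geometric's negative_sampling utility is shown below; the train signature matches the call in the snippet that follows, but the exact function used for the runs may differ:

import torch
from torch.utils.data import DataLoader
from torch_geometric.utils import negative_sampling

def train(model, link_predictor, emb, edge_index, pos_train_edge, batch_size, optimizer):
    # emb: [num_nodes, feat_dim] node features; pos_train_edge: [num_pos_edges, 2]
    model.train()
    link_predictor.train()

    total_loss, num_batches = 0.0, 0
    for edge_ids in DataLoader(range(pos_train_edge.shape[0]), batch_size, shuffle=True):
        optimizer.zero_grad()

        # Node embeddings from the GNN, message passing over the training edges
        node_emb = model(emb, edge_index)

        # Scores for the positive edges in this mini-batch
        pos_edge = pos_train_edge[edge_ids].T  # [2, batch]
        pos_pred = link_predictor(node_emb[pos_edge[0]], node_emb[pos_edge[1]])

        # Sample the same number of "negative" edges, i.e. edges not present in the graph
        neg_edge = negative_sampling(edge_index, num_nodes=emb.shape[0],
                                     num_neg_samples=pos_edge.shape[1])
        neg_pred = link_predictor(node_emb[neg_edge[0]], node_emb[neg_edge[1]])

        # Push scores of real edges towards 1 and scores of sampled negatives towards 0
        loss = -torch.log(pos_pred + 1e-15).mean() - torch.log(1 - neg_pred + 1e-15).mean()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches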


To finally initialize all the modules and train the model, you can use the following snippet:
import torch
import wandb

# Assumes GNNStack, LinkPredictor and train are defined as in the snippets above,
# and that train.pt / val.pt from the split step are available locally
wandb.init(project='gnn-recommender', job_type='training')

train_graph = torch.load('train.pt')
val_graph = torch.load('val.pt')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
optim_wd = 0
epochs = 300
hidden_dim = 1024
dropout = 0.3
num_layers = 2
lr = 1e-5
node_emb_dim = 1  # the in-degree feature is 1-dimensional
batch_size = 1024

train_graph = train_graph.to(device)
val_graph = val_graph.to(device)

# The graph neural network that takes the node features as input and
# message-passes and aggregates them into node embeddings
model = GNNStack(node_emb_dim, hidden_dim, hidden_dim, num_layers, dropout, emb=True).to(device)
link_predictor = LinkPredictor(hidden_dim, hidden_dim, 1, num_layers + 1, dropout).to(device)

optimizer = torch.optim.Adam(
    list(model.parameters()) + list(link_predictor.parameters()),
    lr=lr, weight_decay=optim_wd
)

for epoch in range(epochs):
    train_loss = train(
        model,
        link_predictor,
        train_graph.x.float().to(device),
        train_graph.edge_index,
        train_graph.pos_edge_label_index.T,
        batch_size,
        optimizer
    )
    # Log the epoch loss to W&B to get the loss curve shown below
    wandb.log({'train/loss': train_loss})

wandb.finish()

The loss curve during training looks something like this →

[W&B panel: training loss curve]


Validating Model Performance

One of the best metrics to validate model performance is Hits@K: the count of how many positive edges are ranked in the top K positions against a set of synthetic negative edges.
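A rough sketch of how Hits@K can be computed for the validation edges (assuming pos_val_edge has shape [2, num_edges]; the exact evaluation code used for the runs may differ):

import torch
from torch_geometric.utils import negative_sampling

@torch.no_grad()
def evaluate_hits_at_k(model, link_predictor, emb, edge_index, pos_val_edge, k=20):
    # Fraction of positive validation edges scored higher than the k-th best negative edge
    model.eval()
    link_predictor.eval()

    node_emb = model(emb, edge_index)

    # Scores for the held-out positive edges
    pos_pred = link_predictor(node_emb[pos_val_edge[0]], node_emb[pos_val_edge[1]]).squeeze()

    # Scores for an equal number of randomly sampled negative edges
    neg_edge = negative_sampling(edge_index, num_nodes=emb.shape[0],
                                 num_neg_samples=pos_val_edge.shape[1])
    neg_pred = link_predictor(node_emb[neg_edge[0]], node_emb[neg_edge[1]]).squeeze()

    # A positive edge counts as a "hit" if it beats the k-th highest negative score
    threshold = torch.topk(neg_pred, k).values[-1]
    return (pos_pred > threshold).float().mean().item()

Here, pos_val_edge would be the pos_edge_label_index of the validation split created earlier.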



Testing

To test the model's performance, we took 5,000 nodes from the product graph that it had not seen so far: we had initially sampled 50,000 nodes from the product graph, so for testing we chose nodes that were not part of that set. We then obtained the learned embeddings of all the nodes in the test set and, for 10 of them, used the model to predict which other nodes they were connected to.
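As a sketch of how such predictions can be generated (the helper below is illustrative, not the exact code used for these runs), one can score a query node's embedding against every other node and keep the high-confidence pairs:

import torch

@torch.no_grad()
def recommend(model, link_predictor, x, edge_index, query_nodes, threshold=0.9):
    # For each query node, return the other nodes the model links it to with confidence > threshold
    model.eval()
    link_predictor.eval()

    node_emb = model(x, edge_index)
    recommendations = {}
    for q in query_nodes:
        # Score the query node's embedding against every node in the test graph
        scores = link_predictor(node_emb[q].expand_as(node_emb), node_emb).squeeze()
        candidates = torch.nonzero(scores > threshold).flatten().tolist()
        recommendations[q] = [c for c in candidates if c != q]
    return recommendations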

[W&B panel: PyVis visualization of the predicted links for the test products]

Out of all the predicted connections, we plotted the ones with a confidence score higher than 0.9 in the figure shown above.
Different edges in the graph have different widths: the higher the confidence score, the thicker the edge. The code for this is available in the Appendix.
💡

Conclusion

In this report, we walked through how you can parse your own graph data into a PyTorch Geometric dataset and visualize that data with PyVis and W&B. We also looked at how different utilities in PyG make it easy to perform different operations on graphs, like splitting the graph, sampling negative edges, and more.
Finally, we implemented a link prediction model with GraphSAGE for graph convolution, to recommend related products on Amazon while you're looking at one product. PyG is an incredibly powerful and easy-to-use toolkit for all your geometric deep learning needs!

Appendix

Parsing Metadata
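As a rough sketch, assuming the SNAP amazon-meta.txt format (entries of the form "Id: 15" followed by indented "title:", "group:" and "categories:" lines), a minimal title parser might look like this:

def parse_metadata(path="amazon-meta.txt"):
    # Walk through the metadata file and keep the title of every product id
    titles, current_id = {}, None
    with open(path, encoding="utf-8", errors="ignore") as f:
        for line in f:
            line = line.strip()
            if line.startswith("Id:"):
                current_id = int(line.split("Id:")[1])
            elif line.startswith("title:") and current_id is not None:
                titles[current_id] = line.split("title:", 1)[1].strip()
    return titles

titles = parse_metadata()  # node id -> product title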

Creating Visualization for the Test Set
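As a rough sketch (the names here are illustrative), the predicted links can be drawn with PyVis, sizing each edge by the model's confidence score:

from pyvis.network import Network

def visualize_predictions(pred_edges, titles, out_file="test_predictions.html"):
    # pred_edges: list of (src, dst, confidence) tuples with confidence > 0.9
    net = Network(height="600px", width="100%")
    for src, dst, conf in pred_edges:
        net.add_node(src, label=titles.get(src, str(src)))
        net.add_node(dst, label=titles.get(dst, str(dst)))
        # Thicker edges for higher-confidence predictions
        net.add_edge(src, dst, width=5 * conf, title=f"confidence: {conf:.2f}")
    net.save_graph(out_file)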
