A Brief Introduction to the TensorFlow GNN Framework
In this article, we provide a breakdown of the new TensorFlow GNN Python package, along with code examples and an overview of the API.
Created on February 11|Last edited on February 15
Introduction
While there are plenty of open-source libraries for building and training Graph Neural Networks, the most prominent is PyTorch Geometric. Personally, I've used PyTorch Geometric for the majority of my experiments and paper breakdowns, and I'm fond of its API. However, you can't ignore the developer experience of Keras and TensorFlow. Over the years, there have been attempts to create a GNN framework with a TensorFlow backend, such as Spektral.
Recently, the TensorFlow team released version 1.0 of its own framework for building Graph Neural Networks on the TensorFlow platform. I highly recommend reading the official release notes for version 1.0. We're going to take a high-level look at it today! Here's what we'll cover:
📋 Table of Contents
- Introduction
- 🔑 Key Highlights
- 1️⃣ Graph Tensor
- 2️⃣ Data Preparation and Sampling
- 3️⃣ TF-GNN Models
- 4️⃣ TF-GNN Runner
- 👋 Summary
🔑 Key Highlights
We'll look at the key aspects of this framework in detail below:
- A new tfgnn.GraphTensor class which serves as the basic data handling object for operating on graphs
- Utility classes for performing operations on graph tensors and sampling
- A model gym with implementations of various SOTA and fundamental models
- Utility classes for orchestrating training
Let's look into each feature more deeply:
1️⃣ Graph Tensor
A GNN library is largely defined by the way it represents graphs. DeepMind's Jraph uses a lightweight NamedTuple, GraphsTuple, which holds an ordered collection of graphs in a sparse format. PyTorch Geometric instead uses its own Data objects and custom Dataset classes.
The tfgnn.GraphTensor is a "composite tensor for heterogeneous directed graphs with features". It follows a design similar to RaggedTensor or SparseTensor and serves as an immutable container for several disjoint node sets and edge sets. Each edge set connects a particular pair of node sets: an edge set $E$ from node set $V_s$ to node set $V_t$ satisfies $E \subseteq V_s \times V_t$. This abstraction is very general, and it also lets us represent simple homogeneous graphs with a single node set $V$ and a single edge set $E \subseteq V \times V$.
Each tfgnn.GraphTensor consists of "NodeSets, EdgeSets and a Context (collectively known as graph pieces), which are also composite tensors. The graph pieces consist of fields, which are tf.Tensors and/or tf.RaggedTensors that store the graph structure (esp. the edges between nodes) and user-defined features."
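The following plain-Python sketch makes the invariants described above concrete. The classes `ToyNodeSet`, `ToyEdgeSet` and `ToyGraph` are hypothetical stand-ins, not TF-GNN classes:

- Node sets carry per-component `sizes` and one feature row per node.
- Each edge set names the pair of node sets it connects and stores index lists into them.

The hypothetical `problems()` helper reports anything that does not fit.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

# Hypothetical toy classes that only mirror the *shape* of a GraphTensor;
# they are not part of TF-GNN and do none of its tensor handling.


@dataclass(frozen=True)
class ToyNodeSet:
    sizes: List[int]  # number of nodes in each graph component
    features: Dict[str, list] = field(default_factory=dict)  # one row per node


@dataclass(frozen=True)
class ToyEdgeSet:
    sizes: List[int]  # number of edges in each graph component
    source: Tuple[str, List[int]]  # (node set name, source node index per edge)
    target: Tuple[str, List[int]]  # (node set name, target node index per edge)


@dataclass(frozen=True)
class ToyGraph:
    node_sets: Dict[str, ToyNodeSet]
    edge_sets: Dict[str, ToyEdgeSet]

    def problems(self) -> List[str]:
        """Lists every feature or edge endpoint that does not fit its node set."""
        found: List[str] = []
        for name, node_set in self.node_sets.items():
            num_nodes = sum(node_set.sizes)
            for feature_name, rows in node_set.features.items():
                if len(rows) != num_nodes:
                    found.append(f"{name}.{feature_name}: expected {num_nodes} rows")
        for name, edge_set in self.edge_sets.items():
            num_edges = sum(edge_set.sizes)
            for set_name, indices in (edge_set.source, edge_set.target):
                if set_name not in self.node_sets:
                    found.append(f"{name}: unknown node set {set_name!r}")
                    continue
                if len(indices) != num_edges:
                    found.append(f"{name}: expected {num_edges} endpoints")
                num_nodes = sum(self.node_sets[set_name].sizes)
                if not all(0 <= i < num_nodes for i in indices):
                    found.append(f"{name}: index out of range for {set_name!r}")
        return found


# Two authors who both wrote one paper: a tiny heterogeneous graph.
example = ToyGraph(
    node_sets={"author": ToyNodeSet(sizes=[2]),
               "paper": ToyNodeSet(sizes=[1], features={"year": [2018]})},
    edge_sets={"writes": ToyEdgeSet(sizes=[2], source=("author", [0, 1]),
                                    target=("paper", [0, 0]))},
)
```

Frozen dataclasses echo the immutability of `tfgnn.GraphTensor`. The real class stores its fields as tensors and is considerably richer.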
Let's look at some examples from the official docs:
```python
import tensorflow_gnn as tfgnn

# A homogeneous scalar graph tensor with 1 graph component, 10 nodes and 3
# edges. Edges connect nodes 0 and 3, 5 and 7, 9 and 1. There are no features.
tfgnn.GraphTensor.from_pieces(
    node_sets={
        'node': tfgnn.NodeSet.from_fields(sizes=[10], features={})},
    edge_sets={
        'edge': tfgnn.EdgeSet.from_fields(
            sizes=[3],
            features={},
            adjacency=tfgnn.Adjacency.from_indices(
                source=('node', [0, 5, 9]),
                target=('node', [3, 7, 1])))})
```
The snippet above creates a homogeneous graph with no features. A GraphTensor can consist of zero or more disjoint (sub)graphs called graph components, and the snippet above has a single graph component. Let's have a look at another code snippet from the official docs, this time for a heterogeneous graph:
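The `sizes` field is what splits a node set or edge set into graph components. Below is a minimal sketch, assuming items are stored component by component: `sizes=[4, 6]` means the first 4 nodes form component 0 and the next 6 form component 1. The helper `component_ids` is hypothetical, not a TF-GNN function.

```python
from typing import List


def component_ids(sizes: List[int]) -> List[int]:
    """Returns, for each item of a node or edge set, the graph component it belongs to.

    Assumes items are stored component by component, so sizes[c] consecutive
    items belong to component c. A size of 0 is an empty component.
    """
    ids: List[int] = []
    for component, size in enumerate(sizes):
        ids.extend([component] * size)
    return ids


# sizes=[10] as in the homogeneous example: all nodes are in component 0.
single = component_ids([10])
# The same 10 nodes split into two components of 4 and 6 nodes.
split = component_ids([4, 6])
```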
```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

# Encodes the following imaginary papers:
#   [0] K. Kernel, L. Limit: "Anisotropic approximation", 2018.
#   [1] K. Kernel, L. Limit, M. Minor: "Better bipartite bijection bounds", 2019.
#   [2] M. Minor, N. Normal: "Convolutional convergence criteria", 2020.
# where paper [1] cites [0] and paper [2] cites [0] and [1].
graph = tfgnn.GraphTensor.from_pieces(
    node_sets={
        "paper": tfgnn.NodeSet.from_fields(
            sizes=tf.constant([3]),
            features={
                "tokenized_title": tf.ragged.constant(
                    [["Anisotropic", "approximation"],
                     ["Better", "bipartite", "bijection", "bounds"],
                     ["Convolutional", "convergence", "criteria"]]),
                "embedding": tf.constant(  # One-hot encodes the first letter.
                    [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]),
                "year": tf.constant([2018, 2019, 2020]),
            }),
        "author": tfgnn.NodeSet.from_fields(
            sizes=tf.constant([4]),
            features={
                "name": tf.constant(
                    ["Kevin Kernel", "Leila Limit", "Max Minor", "Nora Normal"]),
            })},
    edge_sets={
        "cites": tfgnn.EdgeSet.from_fields(
            sizes=tf.constant([3]),
            adjacency=tfgnn.Adjacency.from_indices(
                source=("paper", tf.constant([1, 2, 2])),
                target=("paper", tf.constant([0, 0, 1])))),
        "writes": tfgnn.EdgeSet.from_fields(
            sizes=tf.constant([7]),
            adjacency=tfgnn.Adjacency.from_indices(
                source=("author", tf.constant([0, 0, 1, 1, 2, 2, 3])),
                target=("paper", tf.constant([0, 1, 0, 1, 1, 2, 2]))))})
```
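This graph is heterogeneous: it has two node sets (papers and authors), two edge sets (cites and writes), and even text features. The definition is a bit verbose but quite rigid, since every node set, edge set and feature has to be spelled out explicitly.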
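One way to check that you are reading the `source`/`target` indices correctly is to decode the edge sets back into the relationships described in the comment. The sketch below does this in plain Python, reusing the index lists from the official example above as Python lists rather than tensors. The helper `group_by_target` is hypothetical.

```python
from collections import defaultdict
from typing import Dict, List

# Index lists copied from the official example (plain lists instead of tensors).
author_names = ["Kevin Kernel", "Leila Limit", "Max Minor", "Nora Normal"]
writes_source = [0, 0, 1, 1, 2, 2, 3]  # author index of each "writes" edge
writes_target = [0, 1, 0, 1, 1, 2, 2]  # paper index of each "writes" edge
cites_source = [1, 2, 2]               # citing paper
cites_target = [0, 0, 1]               # cited paper


def group_by_target(source: List[int], target: List[int]) -> Dict[int, List[int]]:
    """Collects, for every target index, the source indices of its incoming edges."""
    incoming: Dict[int, List[int]] = defaultdict(list)
    for s, t in zip(source, target):
        incoming[t].append(s)
    return dict(incoming)


# Paper index -> author names, and paper index -> indices of papers citing it.
authors_of = {paper: [author_names[a] for a in authors]
              for paper, authors in group_by_target(writes_source, writes_target).items()}
cited_by = group_by_target(cites_source, cites_target)
```

Paper 2 is absent from `cited_by` because nothing cites it, which matches the comment in the example.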
Let's have a go at defining our own tfgnn.GraphTensor object. Among the simplest graphs are the cycle graphs $C_n$, which consist of $n$ nodes connected in a single closed loop, so that each node is joined to the next. For $n = 4$ a cycle graph simply looks like a square, and for $n = 5$ it looks like a pentagon. Let's write two functions, one to create a jraph.GraphsTuple object and another to create a tfgnn.GraphTensor, and then compare them.
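Whichever library you use, the core of a cycle graph constructor is the same pair of index lists. Below is a minimal plain-Python sketch in which the hypothetical helper `cycle_graph_edges` produces the directed edges $i \to (i + 1) \bmod n$.

```python
from typing import List, Tuple


def cycle_graph_edges(n: int, bidirectional: bool = False) -> Tuple[List[int], List[int]]:
    """Returns (source, target) node indices for the cycle graph C_n.

    Each node i gets an edge to (i + 1) % n. With bidirectional=True the
    reversed edges are appended as well.
    """
    if n < 3:
        raise ValueError("a simple cycle graph needs at least 3 nodes")
    source = list(range(n))
    target = [(i + 1) % n for i in range(n)]
    if bidirectional:
        source, target = source + target, target + source
    return source, target


square = cycle_graph_edges(4)    # C_4
pentagon = cycle_graph_edges(5)  # C_5
```

The two lists can be reused in either library:

- **Jraph:** converted to arrays, they would be the `senders` and `receivers` of a `jraph.GraphsTuple`, with `n_node` and `n_edge` set to the node and edge counts.
- **TF-GNN:** they would be the `source` and `target` indices passed to `tfgnn.Adjacency.from_indices`, as in the official snippets above.

TF-GNN edge sets are directed, so `bidirectional=True` is the option to reach for if messages should flow both ways around the cycle.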
I'll leave it up to the readers to decide which one seems better ☕️.
2️⃣ Data Preparation and Sampling
TensorFlow GNN also comes with primitives for processing tf.train.Example protocol messages that encode graphs, which lets you stream graph data from files.
According to the docs, the most straightforward way to write a stream of GraphTensor instances to files is to:
- Create eager instances of GraphTensor
- Call tensorflow_gnn.write_example() on each one to get a tf.train.Example message
- Serialize the tf.train.Example message to a file
This is a bit cumbersome to set up initially, but the infrastructure to scale up and to run local jobs is there. The pipeline leaves sampling in the hands of the developer. Most pipelines running GNNs in production use neighborhood-based sampling to generate batches from a large graph. Given a few target nodes of interest, they sample the neighborhood around each of those nodes and group the resulting subgraphs into batches.
With TF-GNN, you would then have to encode these sampled subgraphs as GraphTensors.
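The neighborhood sampling idea can be sketched in a few lines. This is a hypothetical single-machine illustration, not TF-GNN's Beam-based sampler or its API. Starting from a seed node, each hop keeps at most `fanouts[h]` randomly chosen neighbors of every node on the current frontier. The result is the kind of small subgraph you would then encode as a GraphTensor.

```python
import random
from typing import Dict, List, Set, Tuple


def sample_neighborhood(
    adjacency: Dict[int, List[int]],
    seed: int,
    fanouts: List[int],
    rng: random.Random,
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Samples a subgraph around `seed`, one hop per entry of `fanouts`.

    At hop h, every node on the current frontier (initially just the seed) keeps
    at most fanouts[h] randomly chosen neighbors. Returns (nodes, edges): the
    seed is nodes[0], and each edge is (neighbor, node), pointing towards the
    node that sampled it.
    """
    nodes = [seed]
    seen: Set[int] = {seed}
    edges: List[Tuple[int, int]] = []
    frontier = [seed]
    for fanout in fanouts:
        next_frontier: List[int] = []
        for node in frontier:
            neighbors = adjacency.get(node, [])
            for neighbor in rng.sample(neighbors, min(fanout, len(neighbors))):
                edges.append((neighbor, node))
                if neighbor not in seen:  # expand each node at most once
                    seen.add(neighbor)
                    nodes.append(neighbor)
                    next_frontier.append(neighbor)
        frontier = next_frontier
    return nodes, edges


# A small undirected graph as adjacency lists (every edge listed in both directions).
adjacency = {0: [1, 2, 3], 1: [0, 4], 2: [0, 5], 3: [0], 4: [1], 5: [2]}
nodes, edges = sample_neighborhood(adjacency, seed=0, fanouts=[2, 1], rng=random.Random(0))
```

Edges point inward, so messages flow towards the seed. That seed is the node whose prediction a root-node task would make.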
TF-GNN also ships with a graph sampler written in Apache Beam that supports learning on local neighborhoods and convenient batching of graph datasets. It offers a scalable, distributed way to sample even the largest publicly available graph datasets. Refer to the docs for more details.
3️⃣ TF-GNN Models
TF-GNN also provides implementations of fundamental Graph Neural Network architectures for use in training pipelines. It ships with out-of-the-box support for Message Passing Neural Networks, Graph Attention Networks, Graph Convolutional Networks and GraphSAGE. We've already covered each of these models in our Graph Neural Networks series.
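These architectures share a common skeleton:

1. Send a message along each edge.
2. Aggregate the messages arriving at each node.
3. Update the node's state.

The sketch below runs one round of this with mean aggregation, loosely GraphSAGE-style but with no learned weights. It is a toy in plain Python, not TF-GNN's implementation, and the helper name is hypothetical.

```python
from typing import List


def mean_message_passing_step(
    states: List[float], source: List[int], target: List[int]
) -> List[float]:
    """One toy round: h'_v = h_v + mean of h_u over all edges u -> v.

    Each message is simply the sender's current scalar state. Nodes without
    incoming edges keep their state unchanged.
    """
    totals = [0.0] * len(states)
    counts = [0] * len(states)
    for u, v in zip(source, target):
        totals[v] += states[u]
        counts[v] += 1
    return [
        h + (totals[v] / counts[v] if counts[v] else 0.0)
        for v, h in enumerate(states)
    ]


# Directed 4-cycle 0 -> 1 -> 2 -> 3 -> 0 with scalar node states.
updated = mean_message_passing_step([1.0, 2.0, 3.0, 4.0], source=[0, 1, 2, 3], target=[1, 2, 3, 0])
```

The real architectures differ mainly in three places:

- how they transform and weight the incoming messages (learned projections, degree normalization, attention scores);
- how they aggregate them;
- how they combine the aggregate with the node's own state.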
4️⃣ TF-GNN Runner
The runner sits at the top of the TF-GNN API pyramid and aims to let you run training jobs with as little code as possible. It is a quick-start toolkit with solutions for common graph learning tasks. It includes common graph learning objectives, distributed training, accelerator support, and handling for many TensorFlow idiosyncrasies.
Let's have a look at the quick start code provided in the official docs:
```python
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner

graph_schema = tfgnn.read_schema("/tmp/graph_schema.pbtxt")
gtspec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)

# len(train_ds_provider.get_dataset(...)) == 8191.
train_ds_provider = runner.TFRecordDatasetProvider(file_pattern="...")
# len(valid_ds_provider.get_dataset(...)) == 1634.
valid_ds_provider = runner.TFRecordDatasetProvider(file_pattern="...")

# Use `embedding` feature as the only node feature.
initial_node_states = lambda node_set, node_set_name: node_set["embedding"]
map_features = tfgnn.keras.layers.MapFeatures(node_sets_fn=initial_node_states)

# Binary classification by the root node.
task = runner.RootNodeBinaryClassification(
    "nodes",
    label_fn=runner.ContextLabelFn("label"))

trainer = runner.KerasTrainer(
    strategy=tf.distribute.TPUStrategy(...),
    model_dir="...",
    steps_per_epoch=8191 // 128,  # global_batch_size == 128
    validation_per_epoch=2,
    validation_steps=1634 // 128)  # global_batch_size == 128

runner.run(
    train_ds_provider=train_ds_provider,
    train_padding=runner.FitOrSkipPadding(gtspec, train_ds_provider),
    # model_fn is a function: Callable[[tfgnn.GraphTensorSpec], tf.keras.Model].
    # Where the returned model both takes and returns a scalar `GraphTensor` for
    # its inputs and outputs.
    model_fn=model_fn,
    optimizer_fn=tf.keras.optimizers.Adam,
    epochs=4,
    trainer=trainer,
    task=task,
    gtspec=gtspec,
    global_batch_size=128,
    feature_processors=[map_features],
    valid_ds_provider=valid_ds_provider)
```
Not exactly minimal. Still, considering that the code above handles everything from data reading and batching to modeling and distributed training, it's reasonable.
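The step counts in the quick start come from integer division: the dataset sizes given in its comments, divided by the global batch size of 128. The small sketch below reproduces them. The helper `steps` and the constant names are hypothetical.

```python
GLOBAL_BATCH_SIZE = 128
NUM_TRAIN_EXAMPLES = 8191  # from the quick start's comment on train_ds_provider
NUM_VALID_EXAMPLES = 1634  # from the quick start's comment on valid_ds_provider


def steps(num_examples: int, global_batch_size: int) -> int:
    """Number of full batches; examples after the last full batch are not counted."""
    return num_examples // global_batch_size


train_steps = steps(NUM_TRAIN_EXAMPLES, GLOBAL_BATCH_SIZE)  # steps_per_epoch
valid_steps = steps(NUM_VALID_EXAMPLES, GLOBAL_BATCH_SIZE)  # validation_steps
train_leftover = NUM_TRAIN_EXAMPLES - train_steps * GLOBAL_BATCH_SIZE
valid_leftover = NUM_VALID_EXAMPLES - valid_steps * GLOBAL_BATCH_SIZE
print(train_steps, valid_steps, train_leftover, valid_leftover)  # 63 12 127 98
```

So each epoch runs 63 training steps, covering 8,064 of the 8,191 training examples. Each validation run covers 1,536 of the 1,634 validation examples.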
There are four main abstractions to consider here:
- DatasetProvider: The DatasetProvider is an abstract base class that provides a tf.data.Dataset. The dataset is expected to be unbatched and to contain serialized tf.train.Examples of GraphTensors.
- Task: The task represents a learning objective for a GNN model and defines all the non-GNN pieces. TF-GNN comes with some pre-written tasks for major use cases such as Graph Classification and Root Node Classification.
- Trainer: The trainer is an abstract base class that provides the training and validation loops. These may use tf.keras.Model.fit or arbitrary custom training loops.
- GraphTensorProcessorFn: A GraphTensorProcessorFn performs feature processing on the GraphTensors of a dataset. Importantly, all GraphTensorProcessorFns are applied in a tf.data.Dataset.map call.
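Every GraphTensorProcessorFn runs inside a `tf.data.Dataset.map` call, so you can think of the list of processors as one composed function applied to each dataset element. Below is a plain-Python analogue, assuming the processors are applied in list order. The helpers `compose` and `map_processors` are hypothetical, and dictionaries stand in for GraphTensors.

```python
from functools import reduce
from typing import Any, Callable, Dict, Iterable, Iterator, List

# Hypothetical stand-in type: in the runner, elements would be GraphTensors.
ProcessorFn = Callable[[Any], Any]


def compose(processors: List[ProcessorFn]) -> ProcessorFn:
    """Chains processors so that they run in list order on each element."""
    def apply_all(element: Any) -> Any:
        return reduce(lambda acc, fn: fn(acc), processors, element)
    return apply_all


def map_processors(dataset: Iterable[Any], processors: List[ProcessorFn]) -> Iterator[Any]:
    """Plain-Python analogue of a single dataset.map(compose(processors)) call."""
    return map(compose(processors), dataset)


def keep_embedding(example: Dict[str, Any]) -> Dict[str, Any]:
    """Keeps only the `embedding` feature, similar in spirit to the quick start's MapFeatures."""
    return {"embedding": example["embedding"]}


def double(example: Dict[str, Any]) -> Dict[str, Any]:
    """Scales every (list-valued) feature by 2."""
    return {name: [2 * x for x in values] for name, values in example.items()}


examples = [{"embedding": [1, 2], "year": 2018}, {"embedding": [3], "year": 2019}]
processed = list(map_processors(examples, [keep_embedding, double]))
```

Order matters here: `keep_embedding` has to run before `double`, because `double` expects every feature to be a list.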
👋 Summary
In this article, we looked at a new open-source library for modeling and training Graph Neural Networks on a variety of tasks and compared it with existing libraries. We also walked through example code for creating GraphTensor objects and briefly went over the library's main components.
If you want more reports covering Graph Neural Networks with code implementations, let us know in the comments below or on our community Discord ✨!
Check out these other reports on Fully Connected covering other topics and ideas based on Graph Neural Networks.
Introduction to Graph Neural Networks
Interested in Graph Neural Networks and want a roadmap on how to get started? In this article, we'll give a brief outline of the field and share blogs and resources!
A Brief Introduction to Graph Contrastive Learning
This article provides an overview of "Deep Graph Contrastive Representation Learning" and introduces a general formulation for Contrastive Representation Learning on Graphs using W&B for interactive visualizations. It includes code samples for you to follow!