
De Novo Molecule Generation with GCPNs using TorchDrug

How reinforcement learning, specifically graph convolutional policy networks, can help create brand-new molecules to treat real-world diseases
The generation of new molecules has become one of the most popular approaches to finding treatments for novel diseases and, unsurprisingly, much of that generation is now done with machine learning. Of course, molecules come with hard-and-fast rules: drug molecules have to possess certain properties to be viable in the real world.
This is where reinforcement learning comes into the picture.
Here, it's important to highlight graph convolutional policy networks (GCPNs). First proposed in 2018, GCPNs were a cornerstone in the application of graph neural networks and reinforcement learning to the drug discovery problem. Using RL, we can optimize the molecular graph generation process to add specific substructures that fulfill our requirements.
Then along came the good folks at TorchDrug, who created a one-stop shop of implementations for a ton of graph-based models across a very wide variety of problems. The one we're concerned with today is molecule generation with GCPNs.
In this report, we'll walk you through molecular graphs, explain how graph convolutional policy networks work, and finally train a model of our own using TorchDrug, with logging handled by Weights & Biases. Let's get started:

A Quick Note About Molecular Graphs

First, in case you need a quick refresher: molecules can in fact be represented as heterogeneous graphs. Here, atoms are equivalent to nodes and bonds are analogous to edges. The graph can consist of $d$ types of nodes and $b$ types of edges. A graph with $n$ nodes can then be represented using a node feature matrix $F \in \{0, 1\}^{n \times d}$ and an adjacency tensor $E \in \{0, 1\}^{n \times n \times b}$.
Source: Improving graphs of cycles approach to structural similarity of molecules. PLoS ONE: 14(12)
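To make this concrete, here's a minimal sketch of how TorchDrug turns a SMILES string into such a graph (the phenol SMILES is just an illustrative choice):

from torchdrug import data

# Phenol: a benzene ring with a hydroxyl group
mol = data.Molecule.from_smiles("Oc1ccccc1")
print(mol.num_node, mol.num_edge)  # atoms (nodes) and bonds (edges)
print(mol.node_feature.shape)      # an n x d node feature matrix like F above
print(mol.edge_list)               # (head, tail, bond type) triples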

Graph Convolutional Policy Network

Graph Generation

The graph generation process is modeled as a sequence of graph edits: selecting two atoms and constructing an edge between them. This seems pretty straightforward, but the question we're concerned with is: which pair of atoms do we choose? And since we have multiple types of edges, which one should we use?
That's what we want the model to learn! Hence, the GCPN model uses graph convolutions at every step to calculate four probabilities, one after another:
  1. Probability of picking each atom from the graph constructed so far.
  2. Probability of picking a second atom, either from the graph constructed so far or from the $d$ atom types in our vocabulary, conditioned on the atom picked in the previous step.
  3. Probability of using each edge type to connect the atoms picked in the previous two steps.
  4. Probability of stopping the generation process.
At each step, embeddings are calculated for each node in the graph using graph convolutional networks (GCNs) to capture information from its neighbours.
Source: Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation. NeurIPS 2018: 6412-6422
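For reference, the GCPN paper computes these embeddings with a GCN variant that supports multiple edge types. Sketching the paper's propagation rule, with $E_i$ the adjacency matrix of the $i$-th edge type, $\tilde{E}_i = E_i + I$, and $\tilde{D}_i$ its degree matrix:

$$H^{(l+1)} = \mathrm{AGG}\left(\mathrm{ReLU}\left(\tilde{D}_i^{-\frac{1}{2}} \tilde{E}_i \tilde{D}_i^{-\frac{1}{2}} H^{(l)} W_i^{(l)}\right),\ \forall i \in (1, \dots, b)\right)$$

Here $W_i^{(l)}$ is a per-edge-type weight matrix and $\mathrm{AGG}$ aggregates (e.g., mean, sum, or concatenation) across the edge types.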


Rewards

There are intermediate rewards at every step if the action leads to a molecule that does not violate any valency rules, and a final reward when generation terminates, depending on the property for which the model is being optimized. In this report, the property we will be looking at is the penalized LogP (pLogP), which is a measure of:
  • How easy/hard it would be to transport the drug inside the body
  • How easy/hard it would be to synthesize the molecule in a lab.
In this report, we try to maximize the pLogP. A sketch of how pLogP is commonly computed follows below.
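For the curious, here's a minimal sketch of the common unnormalized pLogP formulation: Crippen logP penalized by synthetic accessibility (SA) and by rings larger than six atoms. It uses RDKit and its contrib SA scorer; TorchDrug also ships a version of this as torchdrug.metrics.penalized_logP.

import os
import sys

from rdkit import Chem
from rdkit.Chem import Crippen, RDConfig

# RDKit ships the synthetic accessibility scorer in its contrib directory
sys.path.append(os.path.join(RDConfig.RDContribDir, "SA_Score"))
import sascorer

def penalized_logp(smiles):
    mol = Chem.MolFromSmiles(smiles)
    log_p = Crippen.MolLogP(mol)       # lipophilicity, i.e. transportability
    sa = sascorer.calculateScore(mol)  # 1 (easy) to 10 (hard) to synthesize
    ring_sizes = [len(ring) for ring in mol.GetRingInfo().AtomRings()]
    ring_penalty = max(max(ring_sizes) - 6, 0) if ring_sizes else 0
    return log_p - sa - ring_penalty

print(penalized_logp("Oc1ccccc1"))  # phenol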

Pretraining

The model is initially pretrained on a large set of generic drug-like molecules so that it learns the general rules of molecule generation. The model is provided with a random sub-graph of a graph from the dataset and is then expected to reconstruct the full graph. For this report, a smaller subset of the ZINC250K dataset, consisting of 10K molecules, is used for this task. The implementation of the TorchDrug training engine that uses W&B for logging is given at the end.
The next code snippet loads the dataset and the GCPN policy, then pretrains it for 50 epochs using the new training engine, which logs metrics to W&B and saves the model at each epoch as an artifact that can be loaded for the next step. It also creates a table with three of the 50 molecules generated for validation, along with the distribution and mean of the corresponding pLogPs.
from torch import optim
from torchdrug import datasets, models, tasks

# Load ZINC250k; kekulization and symbol features are required by GCPN
dataset = datasets.ZINC250k("~/molecule-datasets/", kekulize=True,
                            node_feature="symbol")

# Relational GCN that computes the node embeddings at every generation step
model = models.RGCN(input_dim=dataset.node_feature_dim,
                    num_relation=dataset.num_bond_type,
                    hidden_dims=[256, 256, 256, 256], batch_norm=False)

# Pretraining uses the negative log-likelihood (graph reconstruction) criterion
task = tasks.GCPNGeneration(model, dataset.atom_types, max_edge_unroll=12,
                            max_node=38, criterion="nll")

optimizer = optim.Adam(task.parameters(), lr=1e-3)

# Engine is the custom W&B-logging training engine given in the appendix
solver = Engine(task, dataset, None, None, optimizer, gpus=(0,),
                batch_size=128, log_interval=10)
solver.train(num_epoch=50)


At the beginning of the training process, the generated molecules consist almost exclusively of Carbon atoms, but towards the end of training we see the addition of substructures characteristic of drug-like molecules, such as Nitrogen, Oxygen, and Sulphur atoms.

Optimization

Once the generator model is pretrained on generic molecules, the next task is to optimize it to generate molecules with a high pLogP. To do this, proximal policy optimization (PPO) is used to update the weights of the graph convolutional policy so as to maximize the rewards.
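For reference, PPO maximizes the standard clipped surrogate objective, where $r_t(\theta)$ is the probability ratio between the new and old policies and $\hat{A}_t$ is the advantage estimate:

$$L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\hat{A}_t,\ \mathrm{clip}\left(r_t(\theta),\ 1-\epsilon,\ 1+\epsilon\right)\hat{A}_t\right)\right]$$

The clipping keeps each update close to the current policy, which stabilizes training on the sparse, delayed rewards of molecule generation.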
We first load the latest trained model from the previous step by fetching the artifact.
import torch
import wandb
from torch import optim
from torchdrug import models, tasks

api = wandb.Api()

# Fetch the latest pretrained GCPN checkpoint logged as an artifact
artifact = api.artifact("GCPN-pLogP/gcpn:latest")
artifact_dir = artifact.checkout()

# Same architecture as in pretraining (18 symbol features, 3 bond types)
model = models.RGCN(input_dim=18,
                    num_relation=3,
                    hidden_dims=[256, 256, 256, 256], batch_norm=False)

# Switch the task to the PPO criterion with pLogP as the reward; the list
# holds the atomic numbers of the ZINC vocabulary (C, N, O, F, P, S, Cl, Br, I)
task = tasks.GCPNGeneration(model, [6, 7, 8, 9, 15, 16, 17, 35, 53],
                            max_edge_unroll=12, max_node=38, task="plogp",
                            criterion="ppo", reward_temperature=1,
                            agent_update_interval=3, gamma=0.9)

optimizer = optim.Adam(task.parameters(), lr=1e-5)

# dataset1 is the ZINC dataset loaded as in the pretraining step
solver = Engine(task, dataset1, None, None, optimizer,
                gpus=(0,), batch_size=512, log_interval=10)

solver.model.load_state_dict(torch.load("./artifacts/gcpn:latest/model_gcpn.tar"))
Now that we have a model loaded with the pretrained GCPN, we can go ahead with the RL optimization!
solver.train(num_epoch=20, train_type='rl', wandb_logger=True)


At the end of every epoch, we generate 50 molecules and log their pLogPs to validate the performance of the generator model. The histograms of the pLogPs consistently move towards more positive values as training progresses, which shows that the model is learning which substructures to add to maximize the property.
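This per-epoch validation amounts to something like the following sketch. It assumes the task object and wandb import from the snippets above, and uses TorchDrug's generation API and penalized_logP metric:

from torchdrug import metrics

# Sample 50 molecules from the current policy and score them
results = task.generate(num_sample=50)
plogp = metrics.penalized_logP(results)

wandb.log({
    "plogp_mean": plogp.mean().item(),
    "plogp_hist": wandb.Histogram(plogp.tolist()),
    "molecules": wandb.Table(columns=["smiles"],
                             data=[[s] for s in results.to_smiles()[:3]]),
})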




Conclusion

The visualizations showcased in this report make Weights & Biases a really useful tool in the drug discovery workflow, giving immense insight into the model's training phases. Moreover, artifacts provide an easy way to log the dataset as well as the trained models for reuse in downstream tasks, and help ensure that experiments can be consistently reproduced.

Appendix
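The full W&B-enabled training engine isn't reproduced here, but the core idea is a thin subclass of TorchDrug's core.Engine that checkpoints the task and logs it as a versioned artifact after every epoch. A minimal sketch, assuming a run has been initialized with wandb.init(project="GCPN-pLogP"), and glossing over the batch-level metric logging and the train_type-specific details of the real implementation:

import torch
import wandb
from torchdrug import core

class Engine(core.Engine):
    def train(self, num_epoch=1, train_type="nll", wandb_logger=True):
        # train_type selects pretraining vs. RL specifics in the full
        # implementation; this sketch only adds checkpointing on top of
        # TorchDrug's stock training loop.
        for epoch in range(num_epoch):
            super().train(num_epoch=1)  # one epoch with the stock loop
            torch.save(self.model.state_dict(), "model_gcpn.tar")
            if wandb_logger:
                artifact = wandb.Artifact("gcpn", type="model")
                artifact.add_file("model_gcpn.tar")
                wandb.log_artifact(artifact)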
