DeepChem: Molecular Solubility

Predict chemical properties from molecular structure with random forests and deep nets. Made by Stacey Svetlichnaya using Weights & Biases


In biochemical research and drug discovery, we may want to know how readily a molecule dissolves in a particular solvent. Given the chemical structure of a molecule, can we train a machine learning model to predict how soluble that molecule is in aqueous solution? Here I explore molecular solubility prediction, following an excellent tutorial from the DeepChem project, which aims to open-source deep learning for science. Below are the names and chemical structures of a few compounds from the training data.

Data Setup: Featurize Chemical Structure

The original dataset [1] of 1128 chemical compounds maps each compound's chemical structure to a list of molecular properties, including the target metric: ESOL predicted log solubility in mol/L. The histogram below shows that log solubility is roughly normally distributed across the dataset; this is the distribution our model will try to capture.
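As a quick sanity check on the units, log solubility is the base-10 logarithm of the molar solubility in mol/L. The sketch below converts a mass solubility in g/L to that scale. The function name and the example values are hypothetical, chosen only for illustration.

```python
import math

def log_molar_solubility(grams_per_liter: float, molar_mass_g_per_mol: float) -> float:
    """Convert a solubility in g/L to log10 of the molar solubility in mol/L."""
    # g/L divided by g/mol gives mol/L
    moles_per_liter = grams_per_liter / molar_mass_g_per_mol
    return math.log10(moles_per_liter)

# Hypothetical compound: 1 g/L with a molar mass of 100 g/mol is 0.01 mol/L,
# i.e. a log solubility of about -2
print(log_molar_solubility(1.0, 100.0))
```

Under this convention, every step of one unit on the histogram's x-axis corresponds to a tenfold change in molar solubility.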
There are several preprocessing steps to extract feature vectors from these chemical structures, which are represented as strings in SMILES ("simplified molecular-input line-entry system") format, e.g. "c2ccc1scnc1c2".
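DeepChem's fingerprint featurizers (such as dc.feat.CircularFingerprint) use RDKit to compute chemically meaningful bit vectors from SMILES. To show the general idea without chemistry libraries, here is a deliberately crude, hypothetical stand-in: it hashes every short character substring of a SMILES string into a fixed-size bit vector. It is not ECFP and knows nothing about atoms or bonds. It only illustrates how variable-length strings become fixed-length feature vectors that a model can consume.

```python
import zlib

def toy_hashed_fingerprint(smiles: str, size: int = 1024, max_ngram: int = 3) -> list[int]:
    """Hypothetical toy featurizer: hash character n-grams of a SMILES string
    into a fixed-length 0/1 vector. Not a real chemical fingerprint."""
    bits = [0] * size
    for n in range(1, max_ngram + 1):
        for start in range(len(smiles) - n + 1):
            fragment = smiles[start:start + n]
            # crc32 is deterministic across runs, unlike Python's salted hash()
            bits[zlib.crc32(fragment.encode("utf-8")) % size] = 1
    return bits

fingerprint = toy_hashed_fingerprint("c2ccc1scnc1c2")
print(len(fingerprint), sum(fingerprint))  # vector length, number of bits set
```

Every molecule maps to a vector of the same length regardless of its SMILES length. Hash collisions can merge different fragments into one bit, which real fingerprints also accept as a trade-off.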
The data split puts the more common scaffolds into "train" and the rarer scaffolds into "validation". See if you can spot a visual distinction between the two in the panels below.
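A scaffold split groups molecules by their core ring system (scaffold) and assigns each whole group to a single split, so structurally related molecules never straddle train and validation. DeepChem's dc.splits.ScaffoldSplitter computes the scaffolds with RDKit. The sketch below assumes the scaffold keys are already computed and shows only the assignment step: groups are sorted largest first, which is why common scaffolds end up in train. The function name, split fractions and toy scaffold labels are hypothetical.

```python
from collections import defaultdict

def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Assign molecule indices to train/valid/test by scaffold group.

    scaffolds[i] is a precomputed scaffold key for molecule i (hypothetical input).
    """
    groups = defaultdict(list)
    for index, scaffold in enumerate(scaffolds):
        groups[scaffold].append(index)
    # Largest (most common) groups first; ties broken by smallest first index
    ordered = sorted(groups.values(), key=lambda g: (len(g), -g[0]), reverse=True)

    n = len(scaffolds)
    train_cutoff = frac_train * n
    valid_cutoff = (frac_train + frac_valid) * n
    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= train_cutoff:
            train.extend(group)
        elif len(train) + len(valid) + len(group) <= valid_cutoff:
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test

toy_scaffolds = ["benzene"] * 6 + ["pyridine"] * 2 + ["thiazole", "furan"]
train, valid, test = scaffold_split(toy_scaffolds)
print(train, valid, test)  # [0, 1, 2, 3, 4, 5, 6, 7] [8] [9]
```

Because the rare scaffolds are held out, validation scores measure generalization to less familiar chemistry, which is usually harder than a random split.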
[1] John S. Delaney. ESOL: Estimating aqueous solubility directly from molecular structure. Journal of Chemical Information and Computer Sciences, 44(3):1000–1005, 2004.

Fitting with a Random Forest

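First we try fitting the data with a classical machine learning approach: random forests. Following the tutorial, I vary the forest's hyperparameters, including the number of estimators (10, 50 or 100) and the max features setting (e.g. log2).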
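The random forest variants form a small grid over these settings. The sketch below enumerates such a grid. The values 10, 50, 100 and "log2" come from the runs in this report, while "sqrt" is an assumed second option for illustration. Each dict uses scikit-learn's RandomForestRegressor argument names (n_estimators, max_features). Building the models is left out to keep the sketch dependency-free.

```python
from itertools import product

n_estimators_options = [10, 50, 100]
max_features_options = ["sqrt", "log2"]  # "sqrt" is an assumed option for illustration

# One dict per random forest variant, keyed by RandomForestRegressor argument names
rf_configs = [
    {"n_estimators": n, "max_features": m}
    for n, m in product(n_estimators_options, max_features_options)
]
for config in rf_configs:
    print(config)
```

With scikit-learn installed, each config could become RandomForestRegressor(**config), which DeepChem can wrap in dc.models.SklearnModel.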

Learning curve (plot with wandb.sklearn)

From the learning curve plotted above, we can see that all the random forest variants train similarly and quickly overfit the dataset, with the training score converging to almost 1. The test score is much lower but increases with more data, which is encouraging. Increasing the number of estimators from 10 to 50 to 100 noticeably increases the training score. Within an active W&B run (started with wandb.init()), this curve is easy to generate with

import wandb
wandb.sklearn.plot_learning_curve(rf_model.model_instance, train_dataset.X, train_dataset.y)

Measuring R-squared (R^2) on validation data

The best random forest (highest R^2 value) uses 100 estimators and log2 max features (shown in green). The R-squared values (above, right) score each model on the validation dataset after training on the full training data, whereas the learning curve (left) is computed on the training data alone, so the ordering of the models differs slightly between the left and right charts.
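For reference, R^2 compares the model's squared error with that of a baseline that always predicts the mean of the true values: R^2 = 1 - SS_res / SS_tot. This is the same definition scikit-learn's r2_score uses. The pure-Python sketch below applies it to made-up numbers. A perfect model scores 1, a model that always predicts the mean scores 0, and a model that does worse than the mean goes negative.

```python
def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty sequences of equal length")
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # model's squared error
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)          # mean baseline's squared error
    if ss_tot == 0:
        raise ValueError("R^2 is undefined when all true values are equal")
    return 1 - ss_res / ss_tot

y_true = [-3.0, -2.0, -1.0, 0.0]  # made-up log solubilities
print(r2_score(y_true, y_true))                   # 1.0: perfect predictions
print(r2_score(y_true, [-1.5] * 4))               # 0.0: always predict the mean
print(r2_score(y_true, [0.0, -1.0, -2.0, -3.0]))  # -3.0: worse than the mean
```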

Fitting with a simple deep net

Deep nets can outperform random forests

DeepChem wraps a fully-connected network as dc.models.MultitaskRegressor. A brief hyperparameter search quickly finds some configurations that outperform the random forest models, and further sweeps will likely yield even better models. The networks shown here have only one fully-connected layer of variable size, and unsurprisingly the larger sizes tend to do better. The best runs appear at the top of the bar chart above, which is sorted in decreasing order of performance. A layer size of 100 actually seems to yield a negative R^2 score. This is possible under scikit-learn's definition: R^2 drops below zero whenever the model's predictions are worse than always predicting the mean of the validation targets.
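One way to see how layer size changes model capacity is to count trainable parameters. The sketch below assumes plain dense layers with biases, a 1024-dimensional fingerprint input (an assumed size) and a single regression output. DeepChem's MultitaskRegressor may include other components, so treat the counts as approximate. count_dense_params is a hypothetical helper.

```python
def count_dense_params(n_inputs, hidden_sizes, n_outputs=1):
    """Weights plus biases for a stack of fully-connected layers."""
    sizes = [n_inputs, *hidden_sizes, n_outputs]
    # Each layer has fan_in * fan_out weights and fan_out biases
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in zip(sizes, sizes[1:]))

for hidden in ([100], [1000], [200, 100]):
    print(hidden, count_dense_params(1024, hidden))
# [100] 102601
# [1000] 1026001
# [200, 100] 225201
```

Under these assumptions, the first layer dominates the count, so widening it is the quickest way to add free variables.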

Hyperparameter exploration for simple deep net

Starting from the simple fully-connected network provided in the tutorial, I ran a W&B Sweep to explore the space of possible configurations. You can see the time course of the sweep below, with the best overall model variants at each time step connected in blue.
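A W&B sweep is driven by a configuration that names the search method, the metric to optimize and the values to try. The sketch below builds such a config as a plain dict. The metric name "r2" and the layer strings mirror names used in this report. The method and the learning_rate and dropout entries are hypothetical placeholders, not the settings from the runs shown here. parse_layers is a hypothetical helper that turns a layer string like "200 100" into a list of layer sizes.

```python
def parse_layers(spec: str) -> list[int]:
    """Hypothetical helper: "200 100" -> [200, 100]."""
    return [int(size) for size in spec.split()]

sweep_config = {
    "method": "random",  # assumed; W&B also supports "grid" and "bayes"
    "metric": {"name": "r2", "goal": "maximize"},
    "parameters": {
        "layers": {"values": ["100", "1000", "200 100"]},
        "learning_rate": {"min": 0.0001, "max": 0.01},  # hypothetical range
        "dropout": {"values": [0.0, 0.25, 0.5]},        # hypothetical values
    },
}

for spec in sweep_config["parameters"]["layers"]["values"]:
    print(repr(spec), "->", parse_layers(spec))
```

With wandb installed, a dict like this is passed to wandb.sweep(sweep_config, project=...), and wandb.agent then launches runs that read their sampled values from wandb.config.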

Some hyperparameters I varied in this initial exploration

To the right, you can see an automatically generated parameter importance panel, which ranks how much each of these hyperparameters contributes to a higher R^2 score, the metric I'm trying to optimize. Recommendations from this panel:
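W&B's parameter importance panel reports two numbers per hyperparameter: an importance score, from a random forest fit with the hyperparameters as inputs and the metric as the target, and a linear correlation with the metric. The correlation column is easy to reproduce. The sketch below computes a Pearson correlation between one hyperparameter and R^2 over a handful of made-up runs.

```python
import math

def pearson_correlation(xs, ys):
    """Linear (Pearson) correlation between two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)

# Made-up runs: first-layer size vs. validation R^2
layer_sizes = [100, 200, 500, 1000]
r2_scores = [-0.1, 0.3, 0.5, 0.6]
print(round(pearson_correlation(layer_sizes, r2_scores), 3))
```

Correlation only captures linear, one-at-a-time effects. This is why the panel pairs it with the random-forest importance, which can pick up nonlinear effects and interactions between hyperparameters.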

Diving deeper on layer configuration

In the section above, the parallel coordinates plot is not obviously informative about layer configuration: the proportion of better (more yellow) versus worse (more blue) runs is about equal at each node. Grouping the runs by layer configuration gives me more visibility. Below, each line is a set of runs with the same layer configuration. For example, layers: "200 100" means a first fully-connected layer of 200 neurons followed by a second fully-connected layer of 100 neurons, while layers: 1000 means a single fully-connected layer of 1000 neurons. The number of runs in each group appears to the right of the group name, and the group's average R^2 score appears immediately to its right in the "r2" column.
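Grouping runs by layer configuration and averaging their R^2 amounts to a simple group-by. The sketch below runs one over made-up run records. The record fields "layers" and "r2" are assumed to mirror the config and metric names in this report, and summarize_by_layers is a hypothetical helper.

```python
from collections import defaultdict

def summarize_by_layers(runs):
    """Return {layers: (run_count, mean_r2)} for a list of run dicts."""
    scores = defaultdict(list)
    for run in runs:
        scores[run["layers"]].append(run["r2"])
    return {layers: (len(vals), sum(vals) / len(vals)) for layers, vals in scores.items()}

runs = [  # made-up results for illustration only
    {"layers": "1000", "r2": 0.6},
    {"layers": "1000", "r2": 0.4},
    {"layers": "200 100", "r2": 0.7},
]
summary = summarize_by_layers(runs)
# Sort groups by mean R^2, best first, like a sorted grouped table
for layers, (count, mean_r2) in sorted(summary.items(), key=lambda kv: kv[1][1], reverse=True):
    print(f"layers: {layers!r}  runs: {count}  mean r2: {mean_r2:.2f}")
```

Averaging within a group smooths out the other hyperparameters that vary between runs. Groups with only a few runs are still noisy, so the run count matters when comparing them.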

Free variables: the more, the better

Observations from this plot:
Next step: run a more precise sweep on the promising layer combinations, with more layers and larger first layers.
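A follow-up sweep over deeper networks with larger first layers needs a list of candidate layer strings. The sketch below generates them under hypothetical size choices. The first layer comes from a "large" set, and later layers never grow wider than the layer before them. The sizes, the maximum depth of 3 and the helper name are all assumptions for illustration.

```python
from itertools import product

first_layer_sizes = [1000, 2000]     # hypothetical "larger first layers"
later_layer_sizes = [500, 200, 100]  # hypothetical sizes for deeper layers

def candidate_layer_specs(max_depth=3):
    """Yield layer strings like "2000 500 100" whose widths never increase."""
    for depth in range(1, max_depth + 1):
        for first in first_layer_sizes:
            for rest in product(later_layer_sizes, repeat=depth - 1):
                sizes = [first, *rest]
                if all(a >= b for a, b in zip(sizes, sizes[1:])):
                    yield " ".join(str(s) for s in sizes)

specs = list(candidate_layer_specs())
print(len(specs), specs[:4])
```

The resulting strings use the same space-separated format as the layer configurations above, so they could serve directly as the "values" list of a sweep's layers parameter.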