Exploring Deep Learning Hyperparameters with Random Forests

Lukas Biewald

My colleague Lavanya ran a large hyperparameter sweep on a Kaggle Simpsons dataset in Colab here, with the goal of finding the best model for the data. Along the way the sweep produced a lot of hyperparameter data, and I wondered whether I could find useful insights in it.

Here’s a parallel coordinates plot visualizing the results of the hyperparameter search. As you can see, she tried many different values for epochs, learning rate, weight decay, optimizer and batch size. The plot shows how each combination maps to validation accuracy.

Hyperparameter searches like these are expensive and there's never time to test every possible combination of inputs, so most people want general-purpose insights to guide their search. One problem with looking for insights is that hyperparameter searches tend to be messy: the inputs are correlated and there are lots of complicated interactions between them. I was talking to Jeremy Howard at fast.ai and he mentioned that he uses random forest feature importance to explore hyperparameters, which turned out to be really useful on this dataset.

The techniques here are less useful for drawing hard conclusions and more useful as a guide for which hyperparameter values to explore next. All of this analysis could be done on the results of any hyperparameter sweep.

You can find the full code to reproduce this analysis here.

Load the Sweep Results

First we need to load the sweep results from the wandb API. Lavanya's sweep results are public at https://app.wandb.ai/sweep/simpsons and we can copy all of the data into pandas dataframes.

Our dataframes are:

  1. config_df - dataframe of hyperparameters (such as optimizer, learning rate)
  2. summary_df - dataframe of output metrics (such as val_loss, val_acc)
  3. name_df - list of names of individual runs
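A minimal sketch of this loading step, assuming the wandb public API (`wandb.Api().runs`) and that the sweep lives under the `sweep/simpsons` entity/project path suggested by the URL above. `runs_to_frames` and `load_sweep` are hypothetical helper names. W&B's own export example reads summaries through `run.summary._json_dict`; the helper falls back to treating the summary as a plain mapping.

```python
import pandas as pd


def runs_to_frames(runs):
    """Flatten an iterable of W&B runs into config, summary and name dataframes."""
    config_list, summary_list, name_list = [], [], []
    for run in runs:
        # Hyperparameters; skip keys W&B adds internally (they start with "_")
        config_list.append(
            {k: v for k, v in run.config.items() if not k.startswith("_")}
        )
        # Summary metrics such as val_loss and val_acc
        summary = getattr(run.summary, "_json_dict", run.summary)
        summary_list.append(dict(summary))
        name_list.append(run.name)
    config_df = pd.DataFrame.from_records(config_list)
    summary_df = pd.DataFrame.from_records(summary_list)
    name_df = pd.DataFrame({"name": name_list})
    return config_df, summary_df, name_df


def load_sweep(path="sweep/simpsons"):
    """Fetch runs with the W&B public API (the default path is an assumption)."""
    import wandb  # imported lazily so runs_to_frames works without wandb installed

    api = wandb.Api()
    return runs_to_frames(api.runs(path))
```

Calling `config_df, summary_df, name_df = load_sweep()` needs network access and `wandb` installed; keeping the flattening in its own function means it can be checked on stand-in objects.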

Data Pre-Processing

Next, we want to use scikit-learn's random forest algorithm, but it doesn't accept categorical (string) variables and it doesn't handle NaNs. This is a real-world dataset, and as such some of the hyperparameter values are missing.

We preprocess our data by turning NaNs in hyperparameters into the average for the column, which is probably fine for preliminary exploration. We turn categorical hyperparameters into dummy variables where each value has its own column.

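import pandas as pd

# Remove runs where the target (validation accuracy) is NaN
hyperparams = hyperparams[target.notnull()]
target = target[target.notnull()]

# convert categorical columns into dummy variables
# ie a column with "sgd", "rmsprop" and "adam" will become
# three separate binary columns (floats, so statsmodels accepts them)
hyperparams_categorical = pd.get_dummies(hyperparams, dtype=float)

# replace remaining NaNs in numeric hyperparameters with the column mean
hyperparams_categorical = hyperparams_categorical.fillna(hyperparams_categorical.mean())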
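The snippet above starts from `hyperparams` and `target`. One way to get there from the sweep dataframes is sketched below. It assumes the target is the `val_acc` column of `summary_df` and that the rows of `config_df` and `summary_df` line up run for run; `prepare` is a hypothetical name.

```python
import pandas as pd


def prepare(config_df, summary_df, target_col="val_acc"):
    """Turn sweep dataframes into a numeric feature matrix and a target series.

    target_col is assumed to hold validation accuracy in summary_df.
    """
    target = summary_df[target_col]
    keep = target.notnull()
    # Drop runs with no recorded target
    hyperparams = config_df[keep]
    target = target[keep]
    # One binary float column per categorical value, e.g. optimizer_adam
    X = pd.get_dummies(hyperparams, dtype=float)
    # Mean-impute missing numeric hyperparameters
    X = X.fillna(X.mean())
    return X, target
```

Mean imputation is crude, but for a first look at which hyperparameters matter it keeps every run in the analysis.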

Plot Correlations

Now we can look at the individual correlation between each input hyperparameter and our output variable. This gives a quick look at which values are independently correlated with higher accuracy and which are negatively correlated with it.

# correlations of hyperparams to target
corr_list = hyperparams_categorical.corrwith(target)
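With dozens of dummy columns, it helps to sort the correlations by magnitude. A minimal sketch; `rank_correlations` is a hypothetical helper, and it drops columns that never vary, since Pearson correlation is undefined for them and pandas reports NaN.

```python
import pandas as pd


def rank_correlations(X, target):
    """Pearson correlation of each column of X with target, strongest first."""
    corr = X.corrwith(target)
    # A column that never varies (e.g. a dummy for a value every run used)
    # has undefined correlation, which pandas reports as NaN
    corr = corr.dropna()
    # Order by absolute strength but keep the sign
    return corr.reindex(corr.abs().sort_values(ascending=False).index)
```

For example, `rank_correlations(hyperparams_categorical, target).head(10)` lists the ten hyperparameter columns most strongly correlated, in either direction, with validation accuracy.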

Linear Regression

Next up we build a linear model and check which parameters it uses. The coefficients of the linear regression are different from the correlation values for two reasons.

  1. The coefficients are dependent on the absolute magnitude of the inputs, so input values with a small range will have a smaller coefficient.
  2. The coefficients are dependent on each other. So for example the Adam optimizer is independently slightly correlated with higher accuracy, but when all of the other inputs are considered in a linear model it has a small negative effect on accuracy.
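Both effects are easy to reproduce on synthetic data. This sketch uses made-up data and plain NumPy least squares rather than the statsmodels fit used in this post. It rescales one input by 1000 and shows its coefficient shrinking by the same factor, then builds two correlated inputs where one is positively correlated with the target on its own but gets a negative coefficient once the other is in the model.

```python
import numpy as np


def ols_coefficients(X, y):
    """Least-squares coefficients for y ~ X, fitted with an intercept."""
    design = np.column_stack([np.ones(len(y)), X])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef[1:]  # drop the intercept


rng = np.random.default_rng(0)
n = 1000

# Reason 1: rescaling an input rescales its coefficient
x = rng.normal(size=n)
y = 2.0 * x + rng.normal(scale=0.1, size=n)
coef_raw = ols_coefficients(x[:, None], y)[0]               # close to 2
coef_scaled = ols_coefficients((1000 * x)[:, None], y)[0]   # 1000 times smaller

# Reason 2: correlated inputs change each other's coefficients
a = rng.normal(size=n)
b = a + rng.normal(scale=0.3, size=n)  # b is strongly correlated with a
y2 = 1.0 * a - 0.5 * b + rng.normal(scale=0.1, size=n)
corr_b = np.corrcoef(b, y2)[0, 1]                          # positive on its own
coef_b = ols_coefficients(np.column_stack([a, b]), y2)[1]  # negative given a
```

The second case mirrors the Adam example: on its own, `b` tracks the target because it carries `a` along with it, but once `a` is in the model the leftover effect of `b` is negative.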

The P>|t| value gives a sense of how meaningful each hyperparameter is for predicting validation accuracy in our linear model. A low P>|t| gives us confidence that the hyperparameter is associated with validation accuracy even after accounting for the other variables. In this dataset rmsprop generally seems to have a negative effect, while higher values of batch size and weight decay have a positive effect.

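import statsmodels.api as sm
# ordinary least squares: validation accuracy ~ hyperparameters
est = sm.OLS(target, hyperparams_categorical).fit()
print(est.summary())  # coefficients and P>|t| for each hyperparameter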

Build the Random Forest

The biggest problem with the linear models is that they don't account for interactions between the hyperparameters and we know that hyperparameters can have quite a lot of interaction with each other. A tree model will give us a feature importance metric on how much the given feature (hyperparameter) is used to predict the validation accuracy.

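We will use scikit-learn's ExtraTreesRegressor, where the input is our hyperparameters and the output is our validation accuracy. ExtraTreesRegressor (extremely randomized trees) is a close relative of the random forest: it averages the predictions of many decision trees, but it draws split thresholds at random instead of searching for the best one. The number of trees is set with n_estimators, and by default each tree is grown until its leaves are pure or too small to split. Trees are nice for this case because they are scale invariant: a split only compares a hyperparameter against a threshold, so rescaling a hyperparameter (say, measuring it in different units) doesn't change which runs end up on each side. This means we can more safely use this approach on many different types of hyperparameters at once.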

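import pandas as pd
from sklearn.ensemble import ExtraTreesRegressor

# fit 250 extremely randomized trees: hyperparameters -> validation accuracy
forest = ExtraTreesRegressor(n_estimators=250)
forest.fit(hyperparams_categorical, target)
# impurity-based importance of each hyperparameter column, largest first
importances = pd.Series(forest.feature_importances_, index=hyperparams_categorical.columns).sort_values(ascending=False)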
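A quick way to check the scale-invariance point: fit the same ExtraTreesRegressor, with a fixed random_state, on some synthetic integer-valued "hyperparameters" and on a rescaled copy, then compare. Because each random threshold is drawn between a feature's minimum and maximum in the node, multiplying a column by a positive constant should just multiply the thresholds, so the trees split the runs identically. The data is made up, and powers of two are used so the rescaling is exact in floating point.

```python
import numpy as np
from sklearn.ensemble import ExtraTreesRegressor

rng = np.random.default_rng(0)
n = 300
# Integer-valued stand-ins for hyperparameters such as epochs and batch size
X = np.column_stack([
    rng.integers(1, 50, size=n),       # "epochs": 1 to 49
    2 ** rng.integers(4, 9, size=n),   # "batch size": 16 to 256
]).astype(float)
y = np.log(X[:, 0]) - 0.01 * X[:, 1] + rng.normal(scale=0.1, size=n)

# Rescale each column by a power of two (exact in floating point)
scale = np.array([1024.0, 1.0 / 64.0])
X_scaled = X * scale

# Same seed, same settings; only the units of the inputs differ
forest_raw = ExtraTreesRegressor(n_estimators=50, random_state=0).fit(X, y)
forest_scaled = ExtraTreesRegressor(n_estimators=50, random_state=0).fit(X_scaled, y)
```

The two forests report the same feature importances and make the same predictions, whereas a linear coefficient would change by exactly the rescaling factor. Note that this covers rescaling; a nonlinear transform such as taking the log of the learning rate changes where the random thresholds fall.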


We can see that learning rate had the highest importance, but did not have a low P value in our regression model. This suggests that its effect depends in complicated ways on the other hyperparameters we used. We can probably conclude that elu vs relu had limited importance across the range of parameters that we tried. Using weight decay was generally correlated with good outcomes.
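To see how a hyperparameter can look unimportant to correlation or a linear model yet matter a lot to the forest, here is a synthetic example (made up, not from Lavanya's sweep) where two inputs only affect the target through their product, a pure interaction, and a third input has no effect at all.

```python
import numpy as np
from sklearn.ensemble import ExtraTreesRegressor

rng = np.random.default_rng(0)
n = 1000
# Two synthetic "hyperparameters" that only matter through their product,
# plus one that has no effect on the target
x1 = rng.uniform(-1, 1, size=n)
x2 = rng.uniform(-1, 1, size=n)
x3 = rng.uniform(-1, 1, size=n)
y = x1 * x2 + rng.normal(scale=0.05, size=n)

X = np.column_stack([x1, x2, x3])
# The marginal correlation of x1 with the target is close to zero...
corr_x1 = np.corrcoef(x1, y)[0, 1]
# ...but the tree ensemble relies on x1 and x2 far more than on x3
forest = ExtraTreesRegressor(n_estimators=100, random_state=0).fit(X, y)
importances = forest.feature_importances_
```

A linear model on `x1`, `x2` and `x3` would find near-zero coefficients for all three, while the forest's importances single out `x1` and `x2`. That is the same pattern as learning rate above: high importance with no clear linear signal.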

I encourage you to run a hyperparameter sweep on your own dataset and then use this analysis as a guide to find which hyperparameter values to explore next.

Good luck!
