Keras-Tuner with W&B

Integrating wandb with keras-tuner. Made by Aritra Roy Gosthipaty using Weights & Biases


Check out the Kaggle Notebook

An artificial neural network is defined not only by its weights and biases but also by many choices that are fixed before training. These choices, such as the number of neurons, the choice of activation (non-linearity) and the number of layers, are commonly termed hyperparameters. A vast field of research is devoted to hyperparameter optimization: people are interested in turning the knobs not only of the weights and biases but also of the hyperparameters. Several well-established approaches, such as grid search, random search and Bayesian optimization, have already made their mark on this field.

A large share of the time spent on deep learning experiments goes into choosing good hyperparameters, and a good choice can sometimes be a game-changer for an experiment. The topic is widely studied, and with the advent of various search algorithms we can now tune hyperparameters automatically. Searching a hyperparameter space automatically has saved a great deal of time for DL researchers who used to do this by hand.
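To make "searching a hyperparameter space" concrete, here is a minimal, library-free sketch of two of the strategies named above, grid search and random search. The search space, the `toy_objective` function (a stand-in for training and validating a real model) and every other name in it are hypothetical, chosen only for illustration.

```python
import itertools
import random

# Hypothetical search space: candidate values for two hyperparameters.
SEARCH_SPACE = {
    "units": [32, 64, 128, 256],
    "learning_rate": [1e-2, 1e-3, 1e-4],
}


def toy_objective(units, learning_rate):
    """Hypothetical validation loss (lower is better), standing in for training a model."""
    return abs(units - 128) / 128 + abs(learning_rate - 1e-3) * 100


def grid_search(space, objective):
    """Try every combination (the Cartesian product of all candidate values)."""
    names = list(space)
    best = None
    for combo in itertools.product(*(space[name] for name in names)):
        params = dict(zip(names, combo))
        loss = objective(**params)
        if best is None or loss < best[0]:
            best = (loss, params)
    return best


def random_search(space, objective, max_trials, seed=0):
    """Try `max_trials` combinations sampled uniformly at random."""
    rng = random.Random(seed)
    best = None
    for _ in range(max_trials):
        params = {name: rng.choice(values) for name, values in space.items()}
        loss = objective(**params)
        if best is None or loss < best[0]:
            best = (loss, params)
    return best


print(grid_search(SEARCH_SPACE, toy_objective))      # 4 * 3 = 12 evaluations
print(random_search(SEARCH_SPACE, toy_objective, 5))  # 5 evaluations
```

The cost of grid search grows with the product of the number of candidates per hyperparameter, while random search lets you fix the budget (`max_trials`) up front, the same knob that keras-tuner's `RandomSearch` exposes.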

In this article, we look into one such tool for automating hyperparameter tuning: keras-tuner. We will not only understand the basics of the tool but also integrate it with our favourite experiment tracker, wandb.


The API of keras-tuner

The Keras team puts a lot of effort into the API design of its tools, and keras-tuner is no exception.

The API provides four basic interfaces, which are at the heart of its design:

  1. HyperParameters: This class serves as a hyperparameter container. An instance holds the current hyperparameter values as well as the whole search space.
  2. HyperModel: An instance of this class can be thought of as an object that models the entire hyperparameter space. It not only defines the hyperparameter space but also builds DL models by sampling from the hyperparameters.
  3. Oracle: Each instance of this class implements a particular hyperparameter tuning algorithm.
  4. Tuner: A Tuner instance does the hyperparameter tuning. An Oracle is passed as an argument to a Tuner, and the Oracle tells the Tuner which hyperparameters should be tried next.

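The top-down design of the API makes it readable and easy to understand. To recap: the Tuner runs the search, the Oracle decides which hyperparameter values to try next, the HyperModel builds a model from those values, and the HyperParameters object holds the values and the search space.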

Code with keras-tuner

In this section, I explain the basic usage of keras-tuner with an example taken from its own documentation.

After the necessary imports, the first thing we need is the HyperModel that represents the entire search space.

We can build a HyperModel in two ways:

  1. Build models with a function
  2. Subclass the HyperModel class


Building models with a function

Here we write a function that takes a HyperParameters object as its argument. The function samples values from it, builds a model and returns it, so that different calls produce different models from the search space.

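from tensorflow import keras
from tensorflow.keras import layers

# build with function: 'units' and 'learning_rate' are sampled from hp
def build_model(hp):
  model = keras.Sequential()
  model.add(layers.Dense(units=hp.Int('units', min_value=32, max_value=512, step=32), activation='relu'))
  model.add(layers.Dense(10, activation='softmax'))
  model.compile(optimizer=keras.optimizers.Adam(hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])),
                loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  return model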

Subclassing the HyperModel class

With this method, one needs to override the build() method. Inside build(), the user samples from the HyperParameters and builds a suitable model.

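# build with inheritance
from tensorflow import keras
from tensorflow.keras import layers
from keras_tuner import HyperModel

class MyHyperModel(HyperModel):

  def __init__(self, num_classes):
    super().__init__()
    self.num_classes = num_classes

  def build(self, hp):
    model = keras.Sequential()
    model.add(layers.Dense(units=hp.Int('units', min_value=32, max_value=512, step=32), activation='relu'))
    # the output size comes from the constructor, not from the search space
    model.add(layers.Dense(self.num_classes, activation='softmax'))
    model.compile(optimizer=keras.optimizers.Adam(hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])),
                  loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model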

In both cases, the HyperModel builds its models by sampling from the HyperParameters it is given. Interested readers are advised to look into how the hyperparameters are sampled: the package provides not only static choices but also conditional hyperparameters.

Once the HyperModel is ready, it is time to build the Tuner. The Tuner searches the hyperparameter space and returns the best set of hyperparameters it finds. Below, I have written a tuner for each of the two HyperModel settings.

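from keras_tuner import RandomSearch

# tuner for function
tuner = RandomSearch(build_model, objective='val_accuracy', max_trials=5)

# tuner for subclass (an alternative to the tuner above)
hypermodel = MyHyperModel(num_classes=10)
tuner = RandomSearch(hypermodel, objective='val_accuracy', max_trials=5)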

Note: RandomSearch comes with its own Oracle. With a custom Tuner (for example, one subclassed from kt.Tuner), one needs to pass the Tuner an Oracle that implements the search algorithm.

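With everything set, we are ready to run the search. The search method accepts the same arguments as model.fit(). After the search, we can query the tuner for the best models and their hyperparameters.

# x, y, val_x and val_y are the training and validation data
tuner.search(x, y, validation_data=(val_x, val_y))
best_model = tuner.get_best_models(num_models=1)[0]
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]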

Code to Integrate keras-tuner with wandb


How cool would it be to track all the models in one place while keras-tuner searches? Here we integrate wandb with keras-tuner to track all the models that are created and searched through. This not only helps with retrieving the best model but also provides valuable insights.

In this section, we run a modified version of the keras-tuner example that subclasses the Tuner to write a custom training loop.


Here we use a function to build the HyperModel, which is an extremely easy way to define the search space. Inside the function, the model itself is written with the Keras functional API.

This example uses conditional hyperparameters: a for loop creates a tunable number of conv_layers, each of which has its own tunable filters and kernel_size parameters.

import tensorflow as tf

def build_model(hp):
  """Builds a convolutional model.

  Args:
    hp: HyperParameters object. This is the object that helps
      us sample hyperparameters for a particular trial.

  Returns:
    model: A Keras model.
  """
  inputs = tf.keras.Input(shape=(28, 28, 1))
  x = inputs
  # In this example we also get to look at
  # conditional hyperparameter settings.
  # Here `filters_i` and `kernel_size_i` are conditioned on the
  # loop counter, so they only exist for layers that are built.
  for i in range(hp.Int('conv_layers', 1, 3)):
    x = tf.keras.layers.Conv2D(
        filters=hp.Int('filters_' + str(i), 4, 32, step=4, default=8),
        kernel_size=hp.Int('kernel_size_' + str(i), 3, 5),
        activation='relu',
        padding='same')(x)
    # choosing between max pool and avg pool
    if hp.Choice('pooling' + str(i), ['max', 'avg']) == 'max':
      x = tf.keras.layers.MaxPooling2D()(x)
    else:
      x = tf.keras.layers.AveragePooling2D()(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

  # choosing between global max pool and global avg pool
  if hp.Choice('global_pooling', ['max', 'avg']) == 'max':
    x = tf.keras.layers.GlobalMaxPooling2D()(x)
  else:
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
  outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
  model = tf.keras.Model(inputs, outputs)
  return model


Integrating the tuner with wandb to log the config and loss was a piece of cake. The API lets the user override the run_trial method of the kt.Tuner class. Inside run_trial, we have access to the trial's HyperParameters object, whose current values we pass as the config of a wandb run. This means we can not only log the metrics of each model but also compare the hyperparameters with the widgets that wandb provides in its dashboard.

import tensorflow as tf
import keras_tuner as kt
import wandb

class MyTuner(kt.Tuner):
  """Custom Tuner subclassed from `kt.Tuner`."""

  def run_trial(self, trial, train_ds):
    """The overridden `run_trial` function.

    Args:
      trial: The trial object that holds information for the
        current trial.
      train_ds: The training data.
    """
    hp = trial.hyperparameters
    # Batching the data
    train_ds = train_ds.batch(
        hp.Int('batch_size', 32, 128, step=32, default=64))
    # The model for this trial, built by the HyperModel
    model = self.hypermodel.build(hp)
    # Learning rate for the optimizer
    lr = hp.Float('learning_rate', 1e-4, 1e-2, sampling='log', default=1e-3)

    if hp.Choice('optimizer', ['adam', 'sgd']) == 'adam':
      optimizer = tf.keras.optimizers.Adam(lr)
    else:
      optimizer = tf.keras.optimizers.SGD(lr)

    epoch_loss_metric = tf.keras.metrics.Mean()

    # build the train_step
    @tf.function
    def run_train_step(data):
      """The run step.

      Args:
        data: the data that needs to be fit

      Returns:
        loss: the mean loss for the present batch
      """
      images = tf.dtypes.cast(data['image'], 'float32') / 255.
      labels = data['label']
      with tf.GradientTape() as tape:
        # the model ends in a softmax, so it outputs probabilities
        probs = model(images, training=True)
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
      gradients = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(gradients, model.trainable_variables))
      # accumulate the running mean of the loss for this epoch
      epoch_loss_metric.update_state(loss)
      return loss

    # Here we pass the configuration so that
    # the runs are tagged with the hyperparams.
    # This also directly means that we can
    # use the different comparison UI widgets in the
    # wandb dashboard off the shelf.
    run = wandb.init(entity='ariG23498', project='keras-tuner', config=hp.values)
    for epoch in range(10):
      self.on_epoch_begin(trial, model, epoch, logs={})
      for batch, data in enumerate(train_ds):
        self.on_batch_begin(trial, model, batch, logs={})
        batch_loss = float(run_train_step(data))
        self.on_batch_end(trial, model, batch, logs={'loss': batch_loss})
        if batch % 100 == 0:
          loss = epoch_loss_metric.result().numpy()
          # Log the running loss for WANDB (the 'batch_loss' key name is assumed)
          run.log({'batch_loss': loss})
      # Epoch loss logic
      epoch_loss = epoch_loss_metric.result().numpy()
      # Log the epoch loss for WANDB
      run.log({'epoch_loss': epoch_loss, 'epoch': epoch})
      # `on_epoch_end` has to be called so that
      # we can send the logs to the `oracle` which handles the
      # tuning.
      self.on_epoch_end(trial, model, epoch, logs={'loss': epoch_loss})
      # reset the running mean before the next epoch
      epoch_loss_metric.reset_state()
    # Finish the wandb run
    run.finish()



I encourage my readers to spin up a notebook and try this great tool for themselves. For future reference, the keras-tuner documentation is an excellent resource.

Hyperparameter tuning is so widely researched that people have even applied genetic algorithms, evolving models much as living creatures evolve. As a shameless plug, interested readers can check out one of my articles that deconstructs hyperparameter tuning with genetic algorithms.

Connect with me on Twitter @ariG23498