Intro to Keras with Weights & Biases

Lavanya Shukla


In this tutorial, we'll walk through a simple convolutional neural network that classifies the images in the Simpsons dataset using Keras.

We’ll also set up Weights & Biases to log model metrics, inspect performance, and share findings about the best architecture for the network. This example uses Google Colab as a convenient hosted environment, but you can run your own training scripts from anywhere and visualize metrics with W&B's experiment tracking tool.

You can find the accompanying code here. We highly encourage you to fork this notebook, tweak the parameters, or try the model with your own dataset!


Start by installing the experiment tracking library (pip install wandb) and setting up your free W&B account (wandb login).

Training A Simple Neural Network

Define Your Hyperparameters

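# Import the libraries used throughout this tutorial
import tensorflow as tf
import wandb
from wandb.integration.keras import WandbCallback

# Initialize a new wandb run
wandb.init(entity="wandb", project="keras-intro")

# Default values for hyperparameters
config = wandb.config  # config holds and saves hyperparameters and inputs
config.learning_rate = 0.01
config.epochs = 50  # assumed value: the training call reads config.epochs, which was never set here
config.batch_size = 128
config.activation = 'relu'
config.optimizer = 'nadam'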
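
An alternative to setting attributes one by one is to collect the defaults in a dictionary and pass it to wandb.init through its config argument. The sketch below is only an illustration under assumptions, not the notebook's code: the helper names merge_hyperparameters and start_run are hypothetical, and the defaults reuse the learning rate, batch size, activation, and optimizer values shown above.

```python
# Hypothetical helpers for illustration; only wandb.init(config=...) is a W&B call.
DEFAULTS = {
    "learning_rate": 0.01,
    "batch_size": 128,
    "activation": "relu",
    "optimizer": "nadam",
}


def merge_hyperparameters(defaults, overrides=None):
    """Return a new dict with overrides applied, rejecting unknown keys."""
    merged = dict(defaults)  # copy so the defaults are never mutated
    for key, value in (overrides or {}).items():
        if key not in defaults:
            raise KeyError(f"unknown hyperparameter: {key}")
        merged[key] = value
    return merged


def start_run(overrides=None):
    """Start a W&B run; requires wandb to be installed and a logged-in account."""
    import wandb  # imported lazily so merge_hyperparameters works without wandb

    return wandb.init(
        entity="wandb",
        project="keras-intro",
        config=merge_hyperparameters(DEFAULTS, overrides),
    )
```

For example, merge_hyperparameters(DEFAULTS, {"learning_rate": 0.001}) changes only the learning rate, and a misspelled key fails immediately instead of silently logging an unused hyperparameter.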

Define Your Neural Network

Below, we define a simplified version of a VGG19 model in Keras and add a few lines of code to log model metrics, visualize performance and output, and track our experiments easily:

# Define the model architecture - This is a simplified version of the VGG19 architecture
model = tf.keras.models.Sequential()

# Conv2D + MaxPooling2D block with 32 filters
model.add(tf.keras.layers.Conv2D(filters = 32, kernel_size = (3, 3), padding = 'same',
                 activation ='relu', input_shape = input_shape))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))

# Flatten the convolution output (a 3D feature map per image) into a vector for the fully connected layers
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation ='relu'))
model.add(tf.keras.layers.Dense(num_classes, activation = "softmax"))

# Define the optimizer
optimizer = tf.keras.optimizers.Nadam(learning_rate=config.learning_rate, beta_1=0.9, beta_2=0.999, clipnorm=1.0)
model.compile(loss = "categorical_crossentropy", optimizer = optimizer, metrics=['accuracy'])


# Fit the model to the training data
# (model.fit accepts generators directly; fit_generator is deprecated in TensorFlow 2.x)
model.fit(datagen.flow(X_train, y_train, batch_size=config.batch_size),
          steps_per_epoch=len(X_train) // config.batch_size, epochs=config.epochs,
          validation_data=(X_test, y_test), verbose=0,
          callbacks=[WandbCallback(data_type="image", validation_data=(X_test, y_test), labels=character_names),
                     tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)])
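
When training from a generator, steps_per_epoch is the number of batches drawn per epoch, so it is normally an integer derived from the dataset size and the batch size. Here is a minimal sketch of that arithmetic. The function below is a hypothetical helper, and the 10,000-image dataset size is an assumed example rather than a figure from this tutorial.

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_remainder=False):
    """Number of batches needed to pass over num_samples once."""
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")
    if drop_remainder:
        # Skip the final partial batch.
        return num_samples // batch_size
    # Count the final partial batch as one extra step.
    return math.ceil(num_samples / batch_size)


# Assumed example: 10,000 training images with the batch size of 128 used above.
print(steps_per_epoch(10_000, 128))                       # 79
print(steps_per_epoch(10_000, 128, drop_remainder=True))  # 78
```

Either rounding is reasonable when the generator loops indefinitely. The point is that the divisor should match the batch size given to the generator and that the result should be an integer.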

Make Predictions

In this section, we make predictions and call wandb.log() to log custom images, in this case our test images with the predicted probabilities overlaid on top.
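
As a rough sketch of what such logging can look like, the code below attaches the top predicted classes to each image through the caption argument of wandb.Image. This is a simpler stand-in for drawing the probabilities onto the pixels, not the notebook's own overlay code. The helpers top_k_caption and log_predictions, the "predictions" key, and the class names in the usage note are hypothetical.

```python
def top_k_caption(probabilities, labels, k=3):
    """Format the k most likely classes as lines such as 'label: 70.0%'."""
    if len(probabilities) != len(labels):
        raise ValueError("need exactly one probability per label")
    # Sort (label, probability) pairs from most to least likely.
    ranked = sorted(zip(labels, probabilities), key=lambda pair: pair[1], reverse=True)
    return "\n".join(f"{label}: {prob:.1%}" for label, prob in ranked[:k])


def log_predictions(images, batch_probabilities, labels, k=3):
    """Log each image with its top-k caption; requires an active W&B run."""
    import wandb  # imported lazily so top_k_caption works without wandb

    wandb.log({
        "predictions": [
            wandb.Image(image, caption=top_k_caption(probs, labels, k))
            for image, probs in zip(images, batch_probabilities)
        ]
    })
```

With labels ["bart", "homer", "lisa"] and probabilities [0.1, 0.7, 0.2], top_k_caption(..., k=2) yields "homer: 70.0%" followed by "lisa: 20.0%" on the next line. In the tutorial's setting, batch_probabilities would come from model.predict on the test images.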

Visualize Predictions Live

Project Overview

  1. Check out the project page to see your results in the shared project.
  2. Press 'option+space' to expand the runs table, comparing all the results from everyone who has tried this script.
  3. Click on the name of a run to dive in deeper to that single run on its own run page.

Visualize Performance

Click through to a single run to see more details about that run. For example, on this run page you can see the performance metrics I logged when I ran this script.

Review Code

The overview tab picks up a link to the code. In this case, it's a link to the Google Colab. If you're running a script from a git repo, we'll pick up the SHA of the latest git commit and give you a link to that version of the code in your own GitHub repo.

Visualize System Metrics

The System tab on the run page lets you visualize how resource-efficient your model was. It lets you monitor GPU, memory, CPU, disk, and network usage in one spot.

Next Steps

As you can see, tracking experiments with W&B is super easy! We highly encourage you to fork the accompanying notebook, tweak the parameters, or try the model with your own dataset!
