Integrating Keras with Weights & Biases
A step-by-step tutorial where we'll train a simple image classifier and show you how to use Weights & Biases in your Keras projects
Note that W&B has shipped new Keras callbacks - WandbMetricsLogger, WandbModelCheckpoint and WandbEvalCallback - to make it easier to integrate Weights & Biases with your Keras workflow. This tutorial showcases our legacy WandbCallback in depth.
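If you'd rather start with the newer callbacks, here's a minimal sketch (the project name is just a placeholder; this tutorial sticks with the legacy WandbCallback):

import wandb
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint

# A sketch of the newer-style callbacks; the rest of this report uses WandbCallback
wandb.init(project='my-keras-project')
callbacks = [
    WandbMetricsLogger(log_freq='epoch'),                # log train/val metrics per epoch
    WandbModelCheckpoint(filepath='models/best.keras'),  # save model checkpoints
]
# model.fit(x_train, y_train, epochs=10, callbacks=callbacks)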
👋 Introduction
In this report, we'll walk you through how to integrate Weights & Biases in your TensorFlow/Keras workflow. We'll provide code, instructions, some Tables and plots you can interact with to evaluate your model, and a whole lot more. Our goal is simple: to show you how W&B and TensorFlow/Keras can work seamlessly together.
We'll be doing this with the help of the BloodMNIST dataset, part of the larger MedMNIST dataset. Specifically, we'll:
- train an image classifier for this dataset using TensorFlow/Keras,
- use W&B Tables to explore the dataset and evaluate the trained classifier,
- and show how to leverage WandbCallback for experiment tracking and model evaluation.
Last thing before we jump in. If you want to view this report as a Colab with executable code, just follow this link.
First, Let's Install Our Dependencies
For the purposes of this report, we'll assume you're starting from scratch. That means we'll start by installing Weights & Biases. It's pretty easy.
# For Weights and Biases
!pip install -qq wandb
And you're done. (Told you it was easy.)
Next, let's download our dataset:
# To download the dataset
!pip install -qq medmnist
Finally, besides other imports, import the medmnist package:
# For medMNIST dataset
import medmnist
print("medMNIST: ", medmnist.__version__)

from medmnist import INFO
If this is your first time using W&B or you are not logged in, the link that appears after running wandb.login() will take you to the sign-up/login page. Signing up for a free account takes just a few clicks.
import wandb

# Login to W&B
wandb.login()
Next up: Configs
Configuration files in .yaml or .json format are an integral part of most mature machine learning systems. After all, keeping track of the hyperparameters used to train and evaluate your model is essential for reproducing your experiments.
W&B can keep track of all these configs. For this tutorial, we will first define all the hyperparameters needed for training as a dictionary:
configs = dict(
    data_flag = 'bloodmnist',
    image_width = 32,
    image_height = 32,
    batch_size = 128,
    model_name = 'vgg16',
    pretrain_weights = 'imagenet',
    epochs = 100,
    init_learning_rate = 0.001,
    lr_decay_rate = 0.1,
    optimizer = 'adam',
    loss_fn = 'sparse_categorical_crossentropy',
    metrics = ['acc'],
    earlystopping_patience = 5
)
Download and Prepare the Dataset
MedMNIST is a large-scale MNIST-like collection of standardized biomedical images, including twelve 2D datasets and six 3D datasets. All the images are pre-processed to a size of 28x28, and running these experiments requires no prior domain knowledge to start with.
As we mentioned above, in this tutorial we will be using the BloodMNIST dataset. From the dataset description:
The BloodMNIST is based on a dataset of individual normal cells, captured from individuals without infection, hematologic or oncologic disease and free of any pharmacologic treatment at the moment of blood collection. It contains a total of 17,092 images and is organized into 8 classes. We split the source dataset with a ratio of 7:1:2 into training, validation and test set. The source images with resolution 3×360×363 pixels are center-cropped into 3×200×200, and then resized into 3×28×28.
Next, run this code:
info = INFO[configs['data_flag']]
configs['class_names'] = info['label']
configs['image_channels'] = info['n_channels']

info
Each MedMNIST dataset can be downloaded using the download_and_prepare_dataset function below; the downloaded dataset comes in the .npz format.
Each subset (e.g., bloodmnist.npz) is comprised of six keys: train_images, train_labels, val_images, val_labels, test_images and test_labels.
# Prepping our dataset
def download_and_prepare_dataset(data_info: dict):
    """Utility function to download the dataset and
    return train/valid/test images/labels.

    Arguments:
        data_info (dict): Dataset metadata
    """
    data_path = tf.keras.utils.get_file(origin=data_info['url'], md5_hash=data_info['MD5'])

    with np.load(data_path) as data:
        # Get images
        train_images = data['train_images']
        valid_images = data['val_images']
        test_images = data['test_images']

        # Get labels
        train_labels = data['train_labels'].flatten()
        valid_labels = data['val_labels'].flatten()
        test_labels = data['test_labels'].flatten()

    return train_images, train_labels, valid_images, valid_labels, test_images, test_labels
train_images, train_labels, valid_images, valid_labels, test_images, test_labels = download_and_prepare_dataset(info)
Explore the Dataset using W&B Tables
As a TensorFlow/Keras user, you might be familiar with the show_batch function, or you might have written some matplotlib-based code to visualize a few batches of the dataset. That's good for quick inspection, but for most real-life scenarios it's not enough.
Here we will use W&B Tables (wandb.Table) to log the training data, then visualize and query it interactively with W&B. As the name suggests, a Table is a table of data specified by you.
You can log data to W&B Tables row-wise or column-wise. In the section below, we have created the table column-wise. Use add_column to define the name of a column and provide the array of data associated with it. Simply adding the image arrays will not render them in the W&B Tables UI; you have to wrap each image array with wandb.Image. To do so, add_computed_columns is used. You can learn more about these methods here.
Finally, note that W&B Tables is built on top of W&B Artifacts, which can be viewed as a file storage system in W&B (usually for datasets and models). In this section, we have explicitly initialized an Artifact using wandb.Artifact and have added both the train_table and validation_table to the artifact. Alternatively, we could have prepared the table and logged it using wandb.log. Here's a quick example if you are interested.
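A minimal sketch of that wandb.log alternative (the table name, logged key, and sample count here are just for illustration):

import wandb

# A sketch: log a table directly with wandb.log instead of attaching it to an Artifact
run = wandb.init(project='medmnist-bloodmnist', group='viz_data')

demo_table = wandb.Table(columns=['image', 'label_id'])
for img, label in zip(train_images[:10], train_labels[:10]):
    # Wrap each image array with wandb.Image so it renders in the UI
    demo_table.add_data(wandb.Image(img), int(label))

wandb.log({'train_data_demo': demo_table})
wandb.finish()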
Here's a look first at the train Table. We'll show you the code right after you see what we'll be making.
And then the validation Table.
Ok. Let's learn how. First:
# For demonstration purposes
log_full = False

if log_full:
    log_train_samples = len(train_images)
else:
    log_train_samples = 1000
Note that if you want to log the entire dataset as Tables, set log_full to True. Next:
# Initialize a new W&B run
run = wandb.init(project='medmnist-bloodmnist', group='viz_data')

# Initialize a W&B Artifact
ds = wandb.Artifact("medmnist_bloodmnist_dataset", "dataset")

# Initialize an empty table
train_table = wandb.Table(columns=[], data=[])

# Add training data
train_table.add_column('image', train_images[:log_train_samples])
# Add training label_id
train_table.add_column('label_id', train_labels[:log_train_samples])
# Add training class names
train_table.add_computed_columns(lambda ndx, row: {
    "images": wandb.Image(row["image"]),
    "class_names": configs['class_names'][str(row["label_id"])]
})

# Add the table to the Artifact
ds['train_data'] = train_table

# Let's do the same for the validation data
valid_table = wandb.Table(columns=[], data=[])
valid_table.add_column('image', valid_images)
valid_table.add_column('label_id', valid_labels)
valid_table.add_computed_columns(lambda ndx, row: {
    "images": wandb.Image(row["image"]),
    "class_name": configs['class_names'][str(row["label_id"])]
})
ds['valid_data'] = valid_table

# Save the dataset as an Artifact
ds.save()

# Finish the run
wandb.finish()
Let's build a Data Pipeline
Here, we'll use tf.data.Dataset to build our data pipeline:
@tf.function
def preprocess(image: tf.Tensor, label: tf.Tensor):
    """Preprocess the image tensors and parse the labels"""
    # Preprocess images
    image = tf.image.convert_image_dtype(image, tf.float32)

    # Parse label
    label = tf.cast(label, tf.float32)

    return image, label


def prepare_dataloader(images: np.ndarray,
                       labels: np.ndarray,
                       loader_type: str='train',
                       batch_size: int=128):
    """Utility function to prepare dataloader."""
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    if loader_type=='train':
        dataset = dataset.shuffle(1024)

    dataloader = (
        dataset
        .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    return dataloader
Let's initialize our dataloaders.
trainloader = prepare_dataloader(train_images, train_labels, 'train', configs.get('batch_size', 64))
validloader = prepare_dataloader(valid_images, valid_labels, 'valid', configs.get('batch_size', 64))
testloader = prepare_dataloader(test_images, test_labels, 'test', configs.get('batch_size', 64))
Visualize Different Augmented Views [Optional]
Here, let's use W&B Tables to visualize augmented images of a subset of training images.
Augmentation policies should make sense for the given task, and by using W&B Tables here we can visualize how the original images are augmented. For the sake of simplicity, we will visualize just 100 randomly sampled images.
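The snippets below assume an img_augmentation pipeline, which is defined in the companion Colab. A minimal sketch using standard Keras preprocessing layers (the exact augmentation policy may differ) could look like this:

import tensorflow as tf
from tensorflow.keras import layers

# Assumed augmentation pipeline (a sketch; the Colab's exact policy may differ)
img_augmentation = tf.keras.Sequential(
    [
        layers.RandomRotation(factor=0.15),
        layers.RandomFlip(),
        layers.RandomContrast(factor=0.1),
    ],
    name="img_augmentation",
)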
def augment_5_times(img):
    augmented_imgs = []
    for _ in range(5):
        aug_img = tf.squeeze(img_augmentation(img), axis=0)
        # Notice that each image is wrapped with wandb.Image
        wandb_image = wandb.Image(aug_img.numpy())
        augmented_imgs.append(wandb_image)

    return augmented_imgs
We can download the dataset that we have logged as W&B Tables, as shown in the code cell below. Since Tables are saved as W&B Artifacts, we first need to pass the name (the path shown in the UI) of the artifact to use_artifact. You can find the name if you head over to the artifacts tab on the W&B dashboard and click on the API panel.
Get the required table using the get method, providing the name of the table, and use the get_column method to get the data associated with a column. Here, the augment_table is initialized with the column names, and data are added row-wise iteratively.
import random

viz_augment_samples = 100

# Initialize a W&B run
run = wandb.init(project='medmnist-bloodmnist', group='viz_augmentation')

# Use the already logged dataset
train_art = run.use_artifact('ayush-thakur/medmnist-bloodmnist/medmnist_bloodmnist_dataset:latest', type='dataset')

# Get the train_table to access the data
train_table = train_art.get("train_data")

# Get the images, ground truth label, and row index
images = train_table.get_column("images", convert_to="numpy")
labels = train_table.get_column("label_id", convert_to="numpy")
ids = train_table.get_index()

# Shuffle the ids and slice
random.shuffle(ids)
sample_ids = ids[0:viz_augment_samples]

# Create augmentation table
augment_table = wandb.Table(columns=['image', 'truth', 'label_id', 'aug1', 'aug2', 'aug3', 'aug4', 'aug5'])

# Get augmented images and log them onto the table
for sample_id in sample_ids:
    img = images[sample_id]
    label = labels[sample_id]
    augmented_imgs = augment_5_times(tf.expand_dims(img, axis=0))

    augment_table.add_data(
        wandb.Image(img),
        configs['class_names'][str(label)],
        int(label),
        augmented_imgs[0],
        augmented_imgs[1],
        augmented_imgs[2],
        augmented_imgs[3],
        augmented_imgs[4]
    )

# Log the table
wandb.log({'augmented data': augment_table})

# Finish the run
wandb.finish()
Here's an example of the logged table – you'll see the augmented images as columns below:
Let's get modeling:
Modeling with W&B
def get_model(input_shape: tuple=(28, 28, 3),
              resize: tuple=(32, 32, 3),
              dropout_rate: float=0.5,
              num_classes: int=8,
              output_activation: str='softmax'):
    inputs = layers.Input(input_shape)
    resize_img = layers.Resizing(resize[0], resize[1], interpolation='bilinear')(inputs)
    augment_img = img_augmentation(resize_img)

    base_model = tf.keras.applications.VGG16(
        include_top=False,
        weights=configs['pretrain_weights'],
        input_shape=resize,
        input_tensor=augment_img
    )
    base_model.trainable = True

    x = base_model.output
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation=output_activation)(x)

    return models.Model(inputs, outputs)


tf.keras.backend.clear_session()
model = get_model()
📞 Callback
We will regularize our training process using Early Stopping. Let's define our early stopping callback. We'll define the WandbCallback later.
earlystopper = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=configs['earlystopping_patience'],
    verbose=0,
    mode='auto',
    restore_best_weights=True
)
You can use wandb.log to log any useful metric/parameter that's not logged by WandbCallback. Here we are using a learning rate scheduler to exponentially decay the learning rate after the first 7 epochs. Notice the use of wandb.log to capture the learning rate, and of commit=False in particular.
def lr_scheduler(epoch, lr):
    # Log the current learning rate onto W&B
    if wandb.run is None:
        raise wandb.Error("You must call wandb.init() before WandbCallback()")
    wandb.log({'learning_rate': lr}, commit=False)

    if epoch < 7:
        return lr
    else:
        return lr * tf.math.exp(-configs['lr_decay_rate'])

lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
🚉 Train
The train function below encapsulates the training recipe for our classifier.
def train(config: dict,
          callbacks: list,
          verbose: int=0):
    """Utility function to train the model.

    Arguments:
        config (dict): Dictionary of hyperparameters.
        callbacks (list): List of callbacks passed to `model.fit`.
        verbose (int): 0 for silent and 1 for progress bar.
    """
    # Initialize model
    tf.keras.backend.clear_session()
    model = get_model(resize=(config.image_width, config.image_height, config.image_channels))

    # Compile the model
    opt = tf.keras.optimizers.Adam(learning_rate=config.init_learning_rate)
    model.compile(opt, config.loss_fn, metrics=config.metrics)

    # Train the model
    _ = model.fit(
        trainloader,
        epochs=config.epochs,
        validation_data=validloader,
        callbacks=callbacks,
        verbose=verbose
    )

    return model
Train using WandbCallback
In the section below, we will train our classifier using WandbCallback to log all the training and validation metrics to a W&B dashboard by default.
WandbCallback enables you to keep track of your experiments, saves the best model, and helps visualize model performance with just one line of code.
In the code below, we'll use the following arguments:
- monitor = 'val_loss' will monitor the val_loss to save the best model. (Please note that 'val_loss' is the default value for monitor.)
- log_weights = True will save histograms of the weights of our model's layers.
- log_evaluation = True will create a W&B Table of validation data and model prediction. The number of validation samples is controlled by validation_steps if a generator is passed to model.fit.
from wandb.keras import WandbCallback

# Initialize the W&B run
run = wandb.init(project='medmnist-bloodmnist', config=configs, job_type='train')
config = wandb.config

# Define WandbCallback for experiment tracking
wandb_callback = WandbCallback(
    monitor='val_loss',
    log_weights=True,
    log_evaluation=True,
    validation_steps=5
)

# Callbacks
callbacks = [earlystopper, wandb_callback, lr_callback]

# Train
model = train(config, callbacks=callbacks, verbose=1)

# Evaluate the trained model
loss, acc = model.evaluate(validloader)
wandb.log({'evaluate/accuracy': acc})

# Close the W&B run
wandb.finish()
🌱 Advanced Usage
In this section, we'll look at an advanced usage of WandbCallback, namely validation_row_processor and prediction_row_processor.
We will use WandbCallback to log the GradCAM for each validation example along with ground truth labels and model predictions. You can check out my report Interpretability in Deep Learning With W&B - CAM and GradCAM to learn more about GradCAM.
We will be following this tutorial on GradCAM by François Chollet; you can check out the Colab to see how GradCAM is computed. The create_gradcam function requires the name of the layer whose output will be used to compute the GradCAM.
last_conv_layer_name = 'block4_conv3'
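The full create_gradcam implementation lives in the Colab. A minimal sketch following Chollet's recipe (assumed here; the actual function may also resize and overlay the heatmap on the input image) looks like this:

import tensorflow as tf

def create_gradcam(img, model, last_conv_layer_name):
    # Model mapping the input image to the last conv layer's activations
    # and to the final predictions
    grad_model = tf.keras.models.Model(
        model.inputs,
        [model.get_layer(last_conv_layer_name).output, model.output]
    )

    img_array = tf.expand_dims(tf.cast(img, tf.float32) / 255.0, axis=0)

    # Gradient of the top predicted class score w.r.t. the conv layer's output
    with tf.GradientTape() as tape:
        conv_output, preds = grad_model(img_array)
        top_class_channel = preds[:, tf.argmax(preds[0])]
    grads = tape.gradient(top_class_channel, conv_output)

    # Channel-wise importance weights, then the weighted activation map
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = conv_output[0] @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # Normalize to [0, 1] for visualization
    heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-8)
    return heatmap.numpy()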
WandbCallback's validation_row_processor and prediction_row_processor can be used to apply a function to the validation data and the model output, respectively. In the example shown below, these processors are used to log the images, ground truth labels, model predictions, and the GradCAM for model interpretability.
The processors take a callable that receives an ndx (index of the row) and a row (dict of data). The validation_processor function below receives the input image array along with the target label as the row dict. The validation_row_processor applies the function to the validation data and is executed when WandbCallback is initialized (i.e., before model training). It creates a table with two columns, namely input:image and target:class.
def validation_processor(ndx, row):
    return {
        "input:image": wandb.Image(row["input"]),
        "target:class": class_table.index_ref(row["target"])
    }
The prediction_processor function receives the model output (it can be logits, masks, reconstructed images, etc.) and the validation data row index (basically the index of the table). The prediction_row_processor applies the function to the model output once training is over. Notice that in the prediction_processor function we access the logged image at a given val_row using the get_row method.
Since the validation_row_processor is executed before training begins, we have a table with two columns: one with the logged images and the other with the ground truth labels. These columns (and the associated data) can be accessed during the prediction_row_processor call by using the get_row method, as shown.
def prediction_processor(ndx, row):
    # Get the validation image
    valid_image = np.array(row["val_row"].get_row()["input:image"].image)

    return {
        "output:class": class_table.index_ref(np.argmax(row["output"])),
        "gradcam": wandb.Image(create_gradcam(valid_image, model, last_conv_layer_name)),
        "output:logits": {
            class_name: value
            for (class_name, value) in zip(list(config.class_names.values()), row["output"].tolist())
        }
    }
The code block below shows how the processors are used to log the GradCAM of the model. We will first download the validation data logged as W&B Tables.
We have also initialized a class_table, which is nothing but a fancy way to map the target integer ids to string labels. The validation_processor uses class_table.index_ref(row["target"]) to get the string label for the given target.
Notice the use of lambda functions to pass ndx and row to the respective processors.
# Initialize the W&B run
run = wandb.init(project='medmnist-bloodmnist', config=configs, job_type='train')
config = wandb.config

# Get validation table
data_art = run.use_artifact('ayush-thakur/medmnist-bloodmnist/medmnist_bloodmnist_dataset:latest', type='dataset')
valid_table = data_art.get("valid_data")

# Create a class table
class_table = wandb.Table(columns=[], data=[])
class_table.add_column("class_name", list(config.class_names.values()))

# Define WandbCallback for experiment tracking
wandb_callback = WandbCallback(
    log_evaluation=True,
    validation_row_processor=lambda ndx, row: validation_processor(ndx, row),
    prediction_row_processor=lambda ndx, row: prediction_processor(ndx, row),
    validation_steps=4,
    save_model=False
)

# Callbacks
callbacks = [earlystopper, wandb_callback, lr_callback]

# Train
model = train(config, callbacks=callbacks, verbose=1)

# Evaluate the trained model
loss, acc = model.evaluate(validloader)
wandb.log({'evaluate/accuracy': acc})

# Close the W&B run
wandb.finish()
The W&B Table shown below is the result of using the processors. You can see the GradCAM of each image along with the ground truth label and model prediction.
🌾 Conclusion
Weights & Biases' Keras integration enables experiment tracking and plenty more with just a few lines of code. In this report, we've seen some advanced usage of the Keras WandbCallback and different ways of using W&B Tables for evaluation and data exploration.
To sum up: all you need is a free W&B account; then import the WandbCallback and pass it to model.fit(callbacks=[...]) just like any other callback. There are a few more arguments that you can learn about in the documentation page here. Of particular note (a combined sketch follows this list):
- You can log the metrics for each batch by setting log_batch_frequency=1
- You can log the gradients of each layer to debug vanishing or exploding gradient issues by setting log_gradients=True. You will also have to provide the training_data in the format (X, y).
- And if your task is semantic segmentation, you can set input_type='segmentation_mask'.
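Putting those arguments together might look like this (a sketch, not part of the training runs above):

import wandb
from wandb.keras import WandbCallback

# A sketch combining the arguments above (not from this tutorial's runs)
wandb.init(project='medmnist-bloodmnist')
wandb_callback = WandbCallback(
    log_batch_frequency=1,                       # log metrics on every batch
    log_gradients=True,                          # log per-layer gradient histograms
    training_data=(train_images, train_labels),  # required when log_gradients=True
)
model.fit(trainloader, epochs=configs['epochs'], callbacks=[wandb_callback])
wandb.finish()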
If your use case is not covered by the WandbCallback, you can easily write a custom Keras callback and use wandb.log to log the required data to the W&B dashboard.
And, if you have any questions, feel free to drop them in the comments!