Object Localization With Keras and Weights & Biases
This article explores object localization using bounding box regression in Keras and interactively visualizes the model's predictions in Weights & Biases.
Created on October 8|Last edited on March 14
Object localization is the task of locating an instance of a particular object category in an image, typically by specifying a tightly cropped bounding box centered on the instance. Object detection, in contrast, is the task of locating every instance of every target object category in an image.
Object localization is also called "classification with localization", because an architecture that performs image classification can be slightly modified to also predict bounding box coordinates. For more background, check out Andrew Ng's lecture on object localization or Object detection: Bounding box regression with Keras, TensorFlow, and Deep Learning by Adrian Rosebrock.

Figure 1: Difference between image classification, object localization and object detection. (Source)
In this article, we will build an object localization model and train it on a synthetic dataset. We will then interactively visualize our model's predictions in Weights & Biases.
The Dataset
We will use a synthetic dataset based on MNIST for our object localization task. The dataset was created by Laurence Moroney. Instead of 28x28-pixel MNIST images, each image is NxN (here 100x100) pixels with a digit placed somewhere inside it, and the task is to predict the bounding box around the digit.

Figure 2: Samples from the dataset. Every image is 100x100 pixels.
Download the Dataset
This GitHub repo is the original source of the dataset. However, due to this issue, we will use my fork of the original repository.
The repository also provides training and testing .csv files listing each image's name, label, and bounding box coordinates. Note that the coordinates are scaled to [0, 1].
!git clone https://github.com/ayulockin/synthetic_datasets
%cd synthetic_datasets/MNIST/
%mkdir images
!unzip -q MNIST_Converted_Training.zip -d images/
!unzip -q MNIST_Converted_Testing.zip -d images/
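The dataloader below builds its datasets from three parallel arrays per split (image names, labels, and boxes), but the article does not show how they are read from the .csv files. Here is a minimal sketch using Python's csv module. The column names (`filename`, `label`, `xmin`, `ymin`, `xmax`, `ymax`) and the helper name `load_annotations` are hypothetical, so check the header of the actual files and adjust. The box order [x_min, y_min, x_max, y_max] matches how `wandb_bbox` indexes the coordinates later on.

```python
import csv


def load_annotations(csv_path, name_col="filename", label_col="label",
                     bbox_cols=("xmin", "ymin", "xmax", "ymax")):
    """Read image names, integer labels and [0, 1]-scaled boxes from a CSV.

    The default column names are hypothetical placeholders; replace them
    with the header actually used in the repository's .csv files.
    """
    image_names, labels, bboxes = [], [], []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            image_names.append(row[name_col])
            labels.append(int(row[label_col]))
            box = [float(row[c]) for c in bbox_cols]
            # The coordinates are expected to be pre-scaled to [0, 1];
            # fail loudly if a row contains pixel coordinates instead.
            if not all(0.0 <= v <= 1.0 for v in box):
                raise ValueError(f"box for {row[name_col]} is not normalized: {box}")
            bboxes.append(box)
    return image_names, labels, bboxes
```

With the real column names filled in, calling `load_annotations` on the training and testing .csv files yields the `train_image_names, train_labels, train_bbox` (and test) arrays that `tf.data.Dataset.from_tensor_slices` consumes.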
Dataloader Using tf.data
We will use tf.data.Dataset to build our input pipeline. Given an input image, our model has to predict both the class of the object in it (the digit) and its bounding box coordinates.
As the model section shows, this makes the model a multi-output architecture. Check out Keras: Multiple outputs and multiple losses by Adrian Rosebrock to learn more about it.
The tf.data.Dataset pipeline below is set up for multi-output training: along with each image, it returns a dictionary holding the label and the bounding box coordinates. The dictionary keys must match the names of the model's output layers.
import tensorflow as tf

# TRAIN_IMG_PATH / TEST_IMG_PATH point at the unzipped image folders, and the
# *_image_names, *_labels and *_bbox arrays are read from the .csv files.
AUTO = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 32

@tf.function
def preprocess_train(image_name, label, bbox):
    image = tf.io.read_file(TRAIN_IMG_PATH + image_name)
    image = tf.image.decode_png(image, channels=1)
    # The dictionary keys must match the names of the model's output layers.
    return image, {'label': label, 'bbox': bbox}

@tf.function
def preprocess_test(image_name, label, bbox):
    image = tf.io.read_file(TEST_IMG_PATH + image_name)
    image = tf.image.decode_png(image, channels=1)
    # Same dictionary structure as for training.
    return image, {'label': label, 'bbox': bbox}

trainloader = tf.data.Dataset.from_tensor_slices((train_image_names, train_labels, train_bbox))
testloader = tf.data.Dataset.from_tensor_slices((test_image_names, test_labels, test_bbox))

trainloader = (
    trainloader
    .map(preprocess_train, num_parallel_calls=AUTO)
    .shuffle(1024)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

testloader = (
    testloader
    .map(preprocess_test, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
The Model
Bounding Box Regression
Before we build our model, let's briefly discuss bounding box regression. In machine learning, regression is the task of mapping an input to a continuous output variable.
A regression model therefore outputs numbers rather than a class label: in our case, 4 numbers (x_min, y_min, x_max, y_max) that describe a bounding box. We train this system on images paired with their ground-truth bounding boxes, using a mean squared error (L2) loss between the predicted and the ground-truth boxes. Check out this video to learn more about bounding box regression.
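To make the regression loss concrete: writing a box as b = (x_min, y_min, x_max, y_max), the mean squared error used for the bbox output averages the squared coordinate errors over the 4 numbers (Keras then also averages over the batch). The numbers in the worked example are made up for illustration.

```latex
\mathcal{L}_{\text{bbox}}(\hat{b}, b) = \frac{1}{4}\sum_{j=1}^{4}\left(\hat{b}_j - b_j\right)^2,
\qquad b = (x_{\min},\, y_{\min},\, x_{\max},\, y_{\max}) \in [0, 1]^4

% Example: b = (0.20, 0.30, 0.48, 0.58), \hat{b} = (0.30, 0.30, 0.58, 0.58)
\mathcal{L}_{\text{bbox}} = \frac{(0.1)^2 + 0^2 + (0.1)^2 + 0^2}{4} = \frac{0.02}{4} = 0.005
```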

Figure 3: The general model architecture for bounding box regression for object localization task. (Source)
Going back to the model, figure 3 summarizes the architecture. The model consists of three components: a convolutional block (the feature extractor), a classification head, and a regression head.
This is a multi-output configuration. As mentioned in the dataset section, the tf.data.Dataset input pipeline returns a dictionary whose keys are the names of the output layers of the classification head and the regression head.
The code snippet shown below builds our model architecture for object localization.
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D,
                                     GlobalAveragePooling2D, Dropout, Dense)
from tensorflow.keras.models import Model

def get_model():
    inputs = Input(shape=(100, 100, 1))

    # Shared convolutional feature extractor.
    x = Conv2D(32, (3, 3), activation='relu')(inputs)
    x = MaxPooling2D((3, 3))(x)
    x = Conv2D(32, (3, 3), activation='relu')(x)
    x = MaxPooling2D((3, 3))(x)
    x = Conv2D(64, (3, 3), activation='relu')(x)
    x = GlobalAveragePooling2D()(x)

    # Classification head. Notice the name of the layer.
    classifier_head = Dropout(0.3)(x)
    classifier_head = Dense(10, activation='softmax', name='label')(classifier_head)

    # Regression head. Notice the name of the layer.
    reg_head = Dense(64, activation='relu')(x)
    reg_head = Dense(32, activation='relu')(reg_head)
    reg_head = Dense(4, activation='sigmoid', name='bbox')(reg_head)

    return Model(inputs=[inputs], outputs=[classifier_head, reg_head])
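It helps to trace the tensor shapes through get_model(). With the Keras defaults (stride 1 and 'valid' padding for Conv2D; strides equal to the pool size for MaxPooling2D), the 100x100 input shrinks to an 8x8x64 map, which GlobalAveragePooling2D turns into a 64-dimensional vector shared by both heads. This plain-Python sketch (the helper names are hypothetical) does that arithmetic and counts the parameters; the totals should match what model.summary() reports.

```python
def conv_out(size, kernel=3):
    # Conv2D defaults: stride 1, 'valid' padding.
    return size - kernel + 1


def pool_out(size, pool=3):
    # MaxPooling2D defaults: strides equal to pool_size, 'valid' padding.
    return (size - pool) // pool + 1


def feature_map_sizes(input_size=100):
    """Spatial size after each Conv2D / MaxPooling2D layer of get_model()."""
    sizes = [input_size]
    for layer in (conv_out, pool_out, conv_out, pool_out, conv_out):
        sizes.append(layer(sizes[-1]))
    return sizes


def conv_params(kernel, c_in, c_out):
    return kernel * kernel * c_in * c_out + c_out  # weights + biases


def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases


def count_params():
    backbone = conv_params(3, 1, 32) + conv_params(3, 32, 32) + conv_params(3, 32, 64)
    label_head = dense_params(64, 10)  # Dropout has no weights
    bbox_head = dense_params(64, 64) + dense_params(64, 32) + dense_params(32, 4)
    return {"backbone": backbone, "label": label_head, "bbox": bbox_head,
            "total": backbone + label_head + bbox_head}


print(feature_map_sizes())  # [100, 98, 32, 30, 10, 8]
print(count_params())       # {'backbone': 28064, 'label': 650, 'bbox': 6372, 'total': 35086}
```

Most of the roughly 35k parameters sit in the convolutional backbone, and both heads read the same 64 pooled features, which is why the two tasks can share one feature extractor.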
The names given to the two heads are used as keys for the losses dictionary. The classification head uses a softmax activation because this is a multi-class classification setup (digits 0-9). The regression head uses a sigmoid activation because the bounding box coordinates lie in the range [0, 1].
Each head gets a suitable loss: sparse categorical cross-entropy for the integer class labels and mean squared error for the box coordinates. We can optionally give the losses different weights.
losses = {'label': 'sparse_categorical_crossentropy',
          'bbox': 'mse'}

loss_weights = {'label': 1.0,
                'bbox': 1.0}
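These two dictionaries are what you pass to model.compile(loss=losses, loss_weights=loss_weights), and Keras then minimizes the weighted sum of the per-output losses. The sketch below reproduces that combination by hand in plain Python for a single sample with made-up predictions. The helper names are hypothetical, and Keras additionally averages each loss over the batch.

```python
import math


def sparse_categorical_crossentropy(probs, label):
    # Negative log-probability assigned to the integer class label.
    return -math.log(probs[label])


def mse(pred, true):
    # Mean of the squared differences over the box coordinates.
    return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(pred)


def weighted_total_loss(outputs, targets, loss_weights):
    """Combine per-head losses as a weighted sum, keyed by output name."""
    per_head = {
        "label": sparse_categorical_crossentropy(outputs["label"], targets["label"]),
        "bbox": mse(outputs["bbox"], targets["bbox"]),
    }
    total = sum(loss_weights[name] * value for name, value in per_head.items())
    return total, per_head


# One made-up sample: the softmax puts 0.5 on the true digit 7,
# and the predicted box is shifted by 0.1 along x.
probs = [0.5 / 9] * 10
probs[7] = 0.5
outputs = {"label": probs, "bbox": [0.30, 0.30, 0.58, 0.58]}
targets = {"label": 7, "bbox": [0.20, 0.30, 0.48, 0.58]}

total, parts = weighted_total_loss(outputs, targets, {"label": 1.0, "bbox": 1.0})
print(parts, total)
```

With both weights at 1.0 the total is ln 2 + 0.005, about 0.698, so the classification term dominates. Raising the bbox weight (for example to 10.0) makes the box error count ten times as much, which is one knob worth experimenting with if the boxes lag behind the classification accuracy.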
The BBoxLogger - Interactive Visualization of Predictions
When working on object localization or object detection, you can interactively visualize your model's predictions in Weights & Biases. You can log sample images along with their ground-truth and predicted bounding boxes, log multiple boxes per image, and attach scores such as confidence or IoU to each box. Check out the documentation here.
The article Bounding Boxes for Object Detection by Stacey Svetlichnaya walks you through the interactive controls for this tool and covers the various nuances of logging images and bounding box coordinates.
The code snippet below is the helper function for our BBoxLogger callback. The function wandb_bbox returns a wandb.Image that bundles the image with the predicted and ground-truth bounding boxes in the required format. Note that all passed values must be JSON serializable: for example, pred_label should be a plain Python int rather than a NumPy integer, and the coordinates should be Python floats rather than NumPy floats (hence the .tolist() calls in the callback).
import wandb

def wandb_bbox(image, p_bbox, pred_label, t_bbox, true_label, class_id_to_label):
    # Boxes are [x_min, y_min, x_max, y_max] in [0, 1], relative to the image size.
    # class_id_to_label maps each integer class id to its caption string.
    return wandb.Image(image, boxes={
        "predictions": {
            "box_data": [{
                "position": {
                    "minX": p_bbox[0],
                    "maxX": p_bbox[2],
                    "minY": p_bbox[1],
                    "maxY": p_bbox[3]
                },
                "class_id": pred_label,
                "box_caption": class_id_to_label[pred_label]
            }],
            "class_labels": class_id_to_label
        },
        "ground_truth": {
            "box_data": [{
                "position": {
                    "minX": t_bbox[0],
                    "maxX": t_bbox[2],
                    "minY": t_bbox[1],
                    "maxY": t_bbox[3]
                },
                "class_id": true_label,
                "box_caption": class_id_to_label[true_label]
            }],
            "class_labels": class_id_to_label
        }
    })
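As mentioned above, each logged box can carry scores such as IoU. Here is a minimal sketch that computes the intersection over union between a predicted and a ground-truth box (both as [x_min, y_min, x_max, y_max]) and attaches it to a "predictions" entry laid out like the one in wandb_bbox. It assumes the W&B box format accepts an optional scores dictionary per box, as described in the W&B bounding-box documentation; the helper names and the "iou" key are hypothetical.

```python
import json


def iou(box_a, box_b):
    """Intersection over union of two [x_min, y_min, x_max, y_max] boxes."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def prediction_box_data(p_bbox, t_bbox, pred_label, class_id_to_label):
    """One "predictions" box_data entry with an IoU score attached."""
    return {
        "position": {"minX": p_bbox[0], "maxX": p_bbox[2],
                     "minY": p_bbox[1], "maxY": p_bbox[3]},
        "class_id": pred_label,
        "box_caption": class_id_to_label[pred_label],
        # "iou" is a score name chosen for this sketch.
        "scores": {"iou": iou(p_bbox, t_bbox)},
    }


entry = prediction_box_data([0.25, 0.25, 0.75, 0.75], [0.0, 0.0, 0.5, 0.5],
                            3, {3: "3"})
print(json.dumps(entry))  # plain Python types, so it serializes cleanly
```

For the two overlapping boxes in the example, the intersection is 0.0625 and the union is 0.4375, giving an IoU of 1/7. Inside wandb_bbox, an entry like this could replace the hand-written "predictions" box.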
Our BBoxLogger is a custom Keras callback. We pass it to model.fit to log the model's predictions on a fixed batch of test images at the end of every epoch. Weights & Biases automatically overlays the bounding boxes on each image.
import numpy as np
import tensorflow as tf
import wandb

class BBoxLogger(tf.keras.callbacks.Callback):
    def __init__(self):
        super(BBoxLogger, self).__init__()
        # Use one fixed batch of the test set for every epoch's visualization.
        self.val_images, label_bbox = next(iter(testloader))
        self.true_labels = label_bbox['label']
        self.true_bbox = label_bbox['bbox']

    def on_epoch_end(self, epoch, logs=None):
        # Predict the whole batch at once; the model returns [class probabilities, boxes].
        pred_probs, pred_bboxes = self.model.predict(self.val_images, verbose=0)
        localization_list = []
        for idx in range(len(self.val_images)):
            # Ground-truth label and bbox as plain Python types (JSON serializable).
            true_label = int(self.true_labels[idx].numpy())
            t_bbox = self.true_bbox[idx].numpy().tolist()
            # The predicted class is the argmax of the softmax output.
            pred_label = int(np.argmax(pred_probs[idx]))
            p_bbox = pred_bboxes[idx].tolist()
            localization_list.append(wandb_bbox(
                self.val_images[idx].numpy(), p_bbox, pred_label,
                t_bbox, true_label, class_id_to_label))
        wandb.log({"predictions": localization_list})
We will soon look at the results.
Results
Now on to the exciting part. I trained the model with early stopping, using a patience of 10 epochs. Feel free to train the model for more epochs and play with other hyperparameters.
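The article doesn't say which quantity early stopping monitored; tf.keras.callbacks.EarlyStopping defaults to val_loss, which for this model is the weighted sum of both heads' losses. As a sketch of what "a patience of 10 epochs" means, here is the rule in plain Python on made-up loss values. It mirrors the idea of the Keras callback with its default min_delta of 0, not its actual implementation, and the function name is hypothetical.

```python
def early_stopping_epoch(val_losses, patience=10):
    """Return (stop_epoch, best_epoch) under a patience-based stopping rule.

    Training stops once `patience` consecutive epochs pass without the
    monitored loss strictly improving on the best value seen so far.
    """
    best = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Patience never ran out: training ends at the last epoch.
    return len(val_losses) - 1, best_epoch


# Made-up validation losses: improvement until epoch 2, then a long plateau.
val_losses = [1.0, 0.8, 0.7] + [0.75] * 10 + [0.1]
stop, best = early_stopping_epoch(val_losses, patience=10)
print(stop, best)  # 12 2
```

Training halts at epoch 12, so the improvement at epoch 13 is never seen, which is the trade-off a larger patience relaxes. In Keras, passing tf.keras.callbacks.EarlyStopping(patience=10) in the callbacks list of model.fit applies this kind of rule, and restore_best_weights=True additionally rolls the model back to its best epoch.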
Since our task has multiple losses, there are multiple metrics to log and monitor. Weights & Biases logs all of them automatically through the wandb.keras.WandbCallback callback.
The result of BBoxLogger is shown below. Click the ⚙️ icon in the media panel below (Result of BBoxLogger) to explore the interactive controls. You can view the ground-truth and predicted bounding boxes together or separately, and you can even hide individual classes.
Observations
- The model classifies the images accurately, as the classification metrics shown above confirm. High accuracy is expected on MNIST-like datasets.
- The predicted bounding boxes are reasonable but not always tight. Still, it is worth appreciating what the network achieves here: with just a few lines of code, it learns to locate the digits.
Improvements
A few things we can try to improve the bounding box predictions:
- Increase the depth of the regression head and retrain. This might lead to overfitting, but it's worth a try.
- Start from the trained model, freeze the convolutional layers and the classification head, and train only the regression head for a few more epochs.