Image Segmentation Using Keras and W&B

This report explores semantic segmentation with a UNET-like architecture in Keras and interactively visualizes the model's predictions in Weights & Biases.
Ayush Thakur


Are you interested in knowing where an object is in an image? What shape it has? Which pixels belong to it? To answer these questions, we need to segment the image, i.e., classify each pixel according to the object it belongs to. In other words, we give every pixel a label instead of giving the whole image a single label.

Try out semantic segmentation on Google Colab →

Thus, image segmentation is the task of learning a pixel-wise mask for each object in the image. Unlike object detection, which gives the bounding box coordinates for each object present in the image, image segmentation gives a far more granular understanding of the object(s) in the image.
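To make the difference in granularity concrete, here is a small NumPy sketch on toy data (not from the dataset) that derives a bounding box from a pixel-wise mask. A box can always be recovered from a mask, but not the other way round, and the box also covers background pixels that the mask excludes. The helper bbox_from_mask is illustrative.

```python
import numpy as np

# Toy 6x6 pixel-wise mask: 1 marks pixels that belong to the object, 0 is background.
mask = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 0, 1, 1, 0, 0],
    [0, 1, 1, 1, 1, 0],
    [0, 0, 1, 1, 0, 0],
    [0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0, 0],
])

def bbox_from_mask(mask):
    """Return (row_min, col_min, row_max, col_max) enclosing all nonzero pixels."""
    rows, cols = np.nonzero(mask)
    return int(rows.min()), int(cols.min()), int(rows.max()), int(cols.max())

box = bbox_from_mask(mask)
box_area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1)  # pixels inside the box
mask_area = int(mask.sum())                                # pixels that are really the object

print(box, box_area, mask_area)
```

Here the box spans 16 pixels while only 9 belong to the object; the mask keeps that distinction, the box does not.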


Figure 1: Semantic segmentation and instance segmentation.

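Image segmentation can be broadly divided into two types: semantic segmentation, which assigns a class label to every pixel (all pets share the label "pet"), and instance segmentation, which additionally separates individual objects of the same class (each pet gets its own mask).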

In this report, we will build a semantic segmentation model, train it on the Oxford-IIIT Pet Dataset, and interactively visualize its predictions in Weights & Biases.
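The difference between semantic and instance segmentation can be sketched in NumPy with a hypothetical two-pet example. An instance mask gives each object its own ID. Mapping every ID to its class collapses it into a semantic mask in which both pets share one label, and the information about which pet is which is lost. The arrays and the ID-to-class lookup are made up for illustration.

```python
import numpy as np

# Hypothetical instance mask for an image with two pets.
# 0 = background, 1 = first pet, 2 = second pet (instance IDs).
instance_mask = np.array([
    [1, 1, 0, 0, 0],
    [1, 1, 0, 2, 2],
    [0, 0, 0, 2, 2],
])

# Hypothetical lookup: index = instance ID, value = class ID (0 = background, 1 = pet).
lookup = np.array([0, 1, 1])

# Semantic mask: replace each instance ID with its class ID via fancy indexing.
semantic_mask = lookup[instance_mask]

n_instances = len(np.unique(instance_mask)) - 1  # exclude background
n_classes = len(np.unique(semantic_mask)) - 1    # exclude background
print(semantic_mask)
print("instances:", n_instances, "classes:", n_classes)
```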

The Dataset

We will use the Oxford-IIIT Pet Dataset to train our UNET-like semantic segmentation model.

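The dataset consists of pet images and their pixel-wise masks (trimaps), which assign a label to each pixel. Each trimap pixel takes one of three values: 1 for the pet, 2 for the background, and 3 for the pet's outline (border pixels that are left unclassified).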
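As a minimal sketch of what the trimap values mean in practice, the NumPy snippet below takes a toy trimap patch, shifts it to zero-based class indices (the same `mask -= 1` step that scale_down performs), and counts pixels per class. The patch is invented. Only the 1/2/3 convention comes from the dataset.

```python
import numpy as np

# Toy trimap patch using the Oxford-IIIT Pet convention:
# 1 = pet (foreground), 2 = background, 3 = not classified (pet outline).
trimap = np.array([
    [2, 2, 3, 3],
    [2, 3, 1, 1],
    [3, 1, 1, 1],
], dtype=np.uint8)

# Shift to zero-based class indices 0, 1, 2 (cast first so the subtraction is explicit).
labels = trimap.astype(np.int64) - 1

class_names = ["pet", "background", "pet_outline"]
counts = np.bincount(labels.ravel(), minlength=len(class_names))
for name, count in zip(class_names, counts):
    print(f"{name}: {count} pixels")
```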


Figure 2: Pets and their pixel-wise masks.

Download the dataset

!curl -O
!curl -O
!tar -xf images.tar.gz
!tar -xf annotations.tar.gz

Dataset preparation

The images/ and annotations/trimaps/ directories contain the extracted images and their annotations (pixel-wise masks). The images we need are in .jpg format, while the annotations are in .png format. However, these directories also contain files we do not need, so we prepare two lists, input_img_paths and annotation_img_paths, which contain the paths to the required images and annotations.

import os

IMG_PATH = 'images/'
ANNOTATION_PATH = 'annotations/trimaps/'

# keep only the .jpg images
input_img_paths = sorted(
    os.path.join(IMG_PATH, fname)
    for fname in os.listdir(IMG_PATH)
    if fname.endswith(".jpg")
)
# keep only the .png trimaps, skipping hidden files such as "._*"
annotation_img_paths = sorted(
    os.path.join(ANNOTATION_PATH, fname)
    for fname in os.listdir(ANNOTATION_PATH)
    if fname.endswith(".png") and not fname.startswith(".")
)

print(len(input_img_paths), len(annotation_img_paths))

There are a total of 7390 images and annotations. We shall use 1000 images and their annotations as the validation set.
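A quick sanity check is worth running before splitting. Because the two lists are sorted independently, each image path should sit at the same index as its mask, and a joint shuffle must keep the pairs together. The sketch below uses Python's random.Random in place of the sklearn.utils.shuffle call used later, and hypothetical file names in place of a real directory listing. check_pairs and split_pairs are illustrative helpers, not part of the report's code.

```python
import os
import random

def check_pairs(img_paths, mask_paths):
    """Return True if every image path is paired with a mask path of the same file stem."""
    stem = lambda p: os.path.splitext(os.path.basename(p))[0]
    return len(img_paths) == len(mask_paths) and all(
        stem(i) == stem(m) for i, m in zip(img_paths, mask_paths)
    )

def split_pairs(img_paths, mask_paths, n_val=1000, seed=42):
    """Shuffle (image, mask) pairs together, then hold out the last n_val pairs."""
    pairs = list(zip(img_paths, mask_paths))
    random.Random(seed).shuffle(pairs)
    return pairs[:-n_val], pairs[-n_val:]

# Hypothetical file names standing in for the real directory listing.
imgs = [f"images/pet_{i}.jpg" for i in range(7390)]
masks = [f"annotations/trimaps/pet_{i}.png" for i in range(7390)]
assert check_pairs(imgs, masks)

train, val = split_pairs(imgs, masks)
print(len(train), len(val))  # 6390 1000
```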

Dataloader using

We will use to build our input pipeline.


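import tensorflow as tf
from sklearn.utils import shuffle

AUTO = tf.data.AUTOTUNE
IMG_SHAPE = 128   # assumed value; the model below needs a multiple of 16
BATCH_SIZE = 32   # assumed value

def scale_down(image, mask):
    # scale image pixels to [0, 1]; shift trimap values {1, 2, 3} to class indices {0, 1, 2}
    image = tf.cast(image, tf.float32) / 255.0
    mask -= 1
    return image, mask

def load_and_preprocess(img_filepath, mask_filepath):
    # load the image and resize it
    img = tf.io.read_file(img_filepath)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, [IMG_SHAPE, IMG_SHAPE])

    # load the mask; nearest-neighbour resizing keeps every value a valid class index
    mask = tf.io.read_file(mask_filepath)
    mask = tf.image.decode_png(mask, channels=1)
    mask = tf.image.resize(mask, [IMG_SHAPE, IMG_SHAPE], method="nearest")

    img, mask = scale_down(img, mask)

    return img, mask

# shuffle the paths (image/mask pairs stay aligned) and prepare the train-test split
input_img_paths, annotation_img_paths = shuffle(input_img_paths, annotation_img_paths, random_state=42)
input_img_paths_train, annotation_img_paths_train = input_img_paths[:-1000], annotation_img_paths[:-1000]
input_img_paths_test, annotation_img_paths_test = input_img_paths[-1000:], annotation_img_paths[-1000:]

trainloader = tf.data.Dataset.from_tensor_slices((input_img_paths_train, annotation_img_paths_train))
testloader = tf.data.Dataset.from_tensor_slices((input_img_paths_test, annotation_img_paths_test))

trainloader = (
    trainloader
    .map(load_and_preprocess, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

testloader = (
    testloader
    .map(load_and_preprocess, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)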
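One detail in the pipeline deserves care. tf.image.resize uses bilinear interpolation by default, which is right for photos but can corrupt label masks: averaging neighbouring class indices produces values that are not classes at all, or that are the wrong class. Passing method="nearest" to tf.image.resize avoids this. The NumPy sketch below illustrates the effect along a single row, with np.interp standing in for bilinear resizing. It does not call TensorFlow.

```python
import numpy as np

# One row of a label mask with class indices 0, 1 and 2.
row = np.array([0, 0, 2, 2, 1, 1], dtype=np.float32)

# Upsample 6 -> 11 samples with linear interpolation (what bilinear resizing does along one axis).
x_new = np.linspace(0, len(row) - 1, 11)
linear = np.interp(x_new, np.arange(len(row)), row)

# Nearest-neighbour: copy the label of the closest original pixel instead of averaging.
nearest = row[np.rint(x_new).astype(int)]

print(linear)   # contains 1.5 (not a class) and a spurious 1.0 between classes 0 and 2
print(nearest)  # only original class indices
```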

The Model

The model used here is a vanilla UNET architecture. It consists of an encoder and a decoder network joined by skip connections. The input to this architecture is the image, while the output is a pixel-wise map of class probabilities. You can learn more about encoder-decoder (autoencoder) networks in the Towards Deep Generative Modeling with W&B report.

UNET-like architectures are also common in self-supervised deep learning tasks such as image inpainting.

You can learn more about the UNET architecture in this Line by Line Explanation.

Figure 3: A typical UNET architecture. (Source)

The code snippet shown below builds our model architecture for semantic segmentation.

from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate
from tensorflow.keras.models import Model

class SegmentationModel:
  """Build a UNET-like model for semantic segmentation."""
  # IMG_SHAPE must already be defined (it is set alongside the data pipeline)
  def prepare_model(self, OUTPUT_CHANNEL, input_size=(IMG_SHAPE,IMG_SHAPE,3)):
    inputs = Input(input_size)

    # Encoder: each block halves the spatial size and keeps its conv output for a skip connection
    conv1, pool1 = self.__ConvBlock(32, (3,3), (2,2), 'relu', 'same', inputs)
    conv2, pool2 = self.__ConvBlock(64, (3,3), (2,2), 'relu', 'same', pool1)
    conv3, pool3 = self.__ConvBlock(128, (3,3), (2,2), 'relu', 'same', pool2)
    conv4, pool4 = self.__ConvBlock(256, (3,3), (2,2), 'relu', 'same', pool3)
    # Decoder: upsample and concatenate with the matching encoder output
    conv5, up6 = self.__UpConvBlock(512, 256, (3,3), (2,2), (2,2), 'relu', 'same', pool4, conv4)
    conv6, up7 = self.__UpConvBlock(256, 128, (3,3), (2,2), (2,2), 'relu', 'same', up6, conv3)
    conv7, up8 = self.__UpConvBlock(128, 64, (3,3), (2,2), (2,2), 'relu', 'same', up7, conv2)
    conv8, up9 = self.__UpConvBlock(64, 32, (3,3), (2,2), (2,2), 'relu', 'same', up8, conv1)
    conv9 = self.__ConvBlock(32, (3,3), (2,2), 'relu', 'same', up9, False)
    # Notice OUTPUT_CHANNEL and activation
    outputs = Conv2D(OUTPUT_CHANNEL, (3, 3), activation='softmax', padding='same')(conv9)

    return Model(inputs=[inputs], outputs=[outputs])

  def __ConvBlock(self, filters, kernel_size, pool_size, activation, padding, connecting_layer, pool_layer=True):
    # two convolutions, optionally followed by max-pooling
    conv = Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(connecting_layer)
    conv = Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(conv)
    if pool_layer:
      pool = MaxPooling2D(pool_size)(conv)
      return conv, pool
    return conv

  def __UpConvBlock(self, filters, up_filters, kernel_size, up_kernel, up_stride, activation, padding, connecting_layer, shared_layer):
    # two convolutions, a transposed convolution to upsample, then the skip connection
    conv = Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(connecting_layer)
    conv = Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(conv)
    up = Conv2DTranspose(filters=up_filters, kernel_size=up_kernel, strides=up_stride, padding=padding)(conv)
    up = concatenate([up, shared_layer], axis=3)

    return conv, up
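To see why the decoder's skip connections line up, here is a small pure-Python sketch that traces spatial size and channel count through an encoder-decoder with the same layout as SegmentationModel: four 2x2 poolings, filters doubling from 32, and stride-2 transposed convolutions. It assumes square inputs. unet_shapes is a hypothetical helper, and it also shows why the input size must be divisible by 16.

```python
def unet_shapes(img_size, depth=4, base_filters=32):
    """Trace (size, channels) through a UNET-like encoder-decoder with square inputs."""
    if img_size % (2 ** depth) != 0:
        raise ValueError(f"input size must be divisible by {2 ** depth} for the skips to line up")
    skips = []
    size, filters = img_size, base_filters
    for _ in range(depth):
        skips.append((size, filters))  # conv output kept for the skip connection
        size //= 2                     # MaxPooling2D((2, 2))
        filters *= 2
    bottleneck = (size, filters)       # the 512-filter block in SegmentationModel
    decoder = []
    for skip_size, skip_filters in reversed(skips):
        size *= 2                      # Conv2DTranspose with strides=(2, 2)
        assert size == skip_size       # spatial sizes must match to concatenate
        decoder.append((size, skip_filters * 2))  # upsampled + skip channels after concatenate
    return skips, bottleneck, decoder

skips, bottleneck, decoder = unet_shapes(128)
print("encoder:", skips)
print("bottleneck:", bottleneck)
print("decoder:", decoder)
```

For a 128x128 input, the bottleneck is 8x8 with 512 channels, and the last concatenation yields 64 channels at full resolution before conv9.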

Notice that OUTPUT_CHANNEL is 3 for our dataset because there are three classes of pixels, as described in the dataset section. We are doing multi-class classification in which each pixel belongs to exactly one of the three classes.

Also, since this is a per-pixel multi-class classification problem, the output activation function is softmax, applied across the channel axis so that each pixel's three class probabilities sum to 1.
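A minimal NumPy sketch of what the softmax output layer computes: the softmax is taken over the channel axis independently for each pixel, and the predicted mask is the argmax over that axis (the same np.argmax(..., axis=-1) the SemanticLogger callback uses). The logits are made up.

```python
import numpy as np

def softmax(logits, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Hypothetical logits for a 2x2 image with 3 output channels (one per class).
logits = np.array([
    [[4.0, 1.0, 0.5], [0.2, 3.0, 0.1]],
    [[0.1, 0.3, 2.5], [1.0, 1.0, 1.0]],
])

probs = softmax(logits, axis=-1)   # softmax across channels, independently per pixel
pred_mask = probs.argmax(axis=-1)  # class index per pixel, shape (2, 2); ties go to the first class

print(probs.sum(axis=-1))  # each pixel's probabilities sum to 1
print(pred_mask)
```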


OUTPUT_CHANNEL = 3  # one output channel per pixel class

model = SegmentationModel().prepare_model(OUTPUT_CHANNEL)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

Finally, the model is compiled with sparse_categorical_crossentropy. We use the sparse variant because the pixel-wise masks store integer class indices rather than one-hot vectors.
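The sparse and one-hot forms differ only in how the target is stored. This NumPy sketch, with made-up probabilities for two pixels, computes the per-pixel cross-entropy from integer labels and from one-hot labels and shows they agree. Keras then reduces these per-pixel values to a single loss (by default, an average).

```python
import numpy as np

# Hypothetical per-pixel probabilities (2 pixels, 3 classes) and integer labels.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
labels = np.array([0, 2])  # integer class indices, as stored in the masks

# Sparse categorical cross-entropy: -log of the probability assigned to the true class.
sparse_ce = -np.log(probs[np.arange(len(labels)), labels])

# The same loss computed from one-hot targets (what categorical_crossentropy expects).
one_hot = np.eye(probs.shape[-1])[labels]
dense_ce = -(one_hot * np.log(probs)).sum(axis=-1)

print(sparse_ce.mean(), np.allclose(sparse_ce, dense_ce))
```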

The SemanticLogger Callback - Interactive Visualization of Predictions

When working on semantic segmentation, you can interactively visualize your model's predictions in Weights & Biases. If you have images with masks for semantic segmentation, you can log the masks and toggle them on and off in the UI. Check out the official documentation here.

The report Image Masks for Semantic Segmentation by Stacey Svetlichnaya walks you through the interactive controls for this tool. It covers the various nuances of logging images and masks.

The code snippets shown below are the helper functions for our SemanticLogger callback. The function labels returns a dictionary in which the key is the class value and the value is the class label. The function wandb_mask returns a wandb.Image that bundles the image, the prediction mask, and the ground-truth mask in the format W&B expects.

import wandb

# class index -> name; after `mask -= 1` the trimap values are 0 = pet, 1 = background, 2 = outline
segmentation_classes = ['pet', 'background', 'pet_outline']

# returns a dictionary of labels
def labels():
  l = {}
  for i, label in enumerate(segmentation_classes):
    l[i] = label
  return l

# util function for generating interactive image mask from components
def wandb_mask(bg_img, pred_mask, true_mask):
  return wandb.Image(bg_img, masks={
      "prediction": {
          "mask_data": pred_mask,
          "class_labels": labels()
      },
      "ground truth": {
          "mask_data": true_mask,
          "class_labels": labels()
      }
  })

Our SemanticLogger is a custom Keras callback. We can pass it to to log our model's predictions on a small validation set. Weights and Biases will automatically overlay the mask on the image.

import numpy as np
import tensorflow as tf
import wandb

class SemanticLogger(tf.keras.callbacks.Callback):
    def __init__(self):
        super(SemanticLogger, self).__init__()
        # one batch from the validation pipeline, reused every epoch
        self.val_images, self.val_masks = next(iter(testloader))

    def on_epoch_end(self, epoch, logs=None):
        pred_masks = self.model.predict(self.val_images)
        # per-pixel class index with the highest softmax score
        pred_masks = np.argmax(pred_masks, axis=-1).astype(np.uint8)

        # images are floats in [0, 1]; convert_image_dtype rescales them to [0, 255]
        val_images = tf.image.convert_image_dtype(self.val_images, tf.uint8)
        # masks hold class indices, so cast (do not rescale) and drop the channel axis
        val_masks = tf.cast(tf.squeeze(self.val_masks, axis=-1), tf.uint8)

        mask_list = []
        for i in range(len(self.val_images)):
            mask_list.append(wandb_mask(val_images[i].numpy(), pred_masks[i], val_masks[i].numpy()))

        wandb.log({"predictions": mask_list})
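Converting to uint8 is where images and masks need different treatment. tf.image.convert_image_dtype rescales values when it changes dtype (for example, floats in [0, 1] become integers in [0, 255]). That is what an image needs, but it would scramble the class indices in a mask, where a plain cast keeps them intact. The NumPy sketch below mirrors that distinction without calling TensorFlow. to_uint8_image and to_uint8_mask are illustrative names.

```python
import numpy as np

def to_uint8_image(img):
    """Rescale a float image in [0, 1] to uint8 in [0, 255], rounding to the nearest integer."""
    return np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)

def to_uint8_mask(mask):
    """Cast class indices to uint8 without rescaling, so class 2 stays 2."""
    return np.asarray(mask).astype(np.uint8)

img = np.array([[0.0, 0.5], [1.0, 0.25]], dtype=np.float32)
mask = np.array([[0, 1], [2, 1]], dtype=np.int64)  # e.g. the output of np.argmax

print(to_uint8_image(img))   # intensities rescaled to [0, 255]
print(to_uint8_mask(mask))   # class indices unchanged

# Treating the mask like an image rescales and clips it: classes 1 and 2 both become 255.
print(to_uint8_image(mask))
```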

We will shortly look at the results.


Now on to the exciting part. I trained the model for 15 epochs. The loss and validation loss are shown in the chart below. Feel free to train the model for more epochs and play with other hyperparameters.


The output of SemanticLogger is shown below. Click the gear icon in the media panel (Result of SemanticLogger) to open the interactive controls. You can visualize images and masks separately and choose which semantic classes to display.



Conclusion and Final Thoughts

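I hope you enjoyed this report on semantic segmentation. The intention of this report was twofold: to build and train a semantic segmentation model on the Oxford-IIIT Pet Dataset, and to interactively visualize its predictions in Weights & Biases.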

On a final note, here are some resources that might be a good read:

I would love to get your feedback in the comment section. 😄