
Image Segmentation Using Keras and Weights & Biases

This article explores semantic segmentation with a UNET-like architecture in Keras and interactively visualizes the model's predictions using Weights & Biases.
Created on September 23|Last edited on March 9
Do you want to know where an object is in an image, what shape it has, and which pixels belong to it? To answer these questions, we need to segment the image: classify each pixel of the image according to the object it belongs to, rather than giving the whole image a single label.

Try out semantic segmentation on Google Colab →

Thus, image segmentation is the task of learning a pixel-wise mask for each object in the image. Unlike object detection, which gives the bounding box coordinates for each object present in the image, image segmentation gives a far more granular understanding of the object(s) in the image.
Figure 1: Semantic segmentation and Instance segmentation
Image segmentation can be broadly divided into two types:
  • Semantic segmentation - Here, each pixel belongs to a particular class. The left image in figure 1 is an example of semantic segmentation. The pixels belong either to the person (one class) or to the background (another class).
  • Instance segmentation - Here, each pixel also belongs to a particular class. However, pixels belonging to distinct objects of the same class are labeled with different colors (mask values). The right image in figure 1 is an example of instance segmentation. The pixels of the person class are colored differently for each individual person.
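To make the difference concrete, here is a toy NumPy sketch (not from the original report; the arrays are made up for illustration) of a tiny scene with two separate people, encoded first as a semantic mask and then as an instance mask.

```python
import numpy as np

# Toy 4x7 scene with two separate people on a background.
# Semantic mask: every person pixel gets the same class id (1); background is 0.
semantic = np.array([
    [0, 1, 1, 0, 0, 1, 1],
    [0, 1, 1, 0, 0, 1, 1],
    [0, 1, 1, 0, 0, 1, 1],
    [0, 0, 0, 0, 0, 0, 0],
])

# Instance mask: the same pixels, but each person gets its own id (1 and 2).
instance = semantic.copy()
instance[:, 5:] *= 2  # relabel the person in columns 5-6 as instance 2

print(np.unique(semantic))  # [0 1]
print(np.unique(instance))  # [0 1 2]
# Both masks mark the same "person" pixels; only the instance mask tells the two people apart.
print(np.array_equal(semantic > 0, instance > 0))  # True
```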
In this article, we will build a semantic segmentation model and train it on the Oxford-IIIT Pet Dataset. We will interactively visualize our model's predictions in Weights & Biases.




The Dataset

We will use the Oxford-IIIT Pet Dataset to train our UNET-like semantic segmentation model.
The dataset consists of images and their pixel-wise masks, which assign one of three labels to each pixel:
  • Class 1: Pixels belonging to the pet.
  • Class 2: Pixels belonging to the outline of the pet.
  • Class 3: Pixels belonging to the background.
We will get the dataset from the TensorFlow Datasets catalog, which makes it easy to download the data and use it in a tf.data pipeline. The dataset consists of 7,349 images, of which 3,680 are in the training set and the remaining 3,669 are in the test set.
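```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Download (on first use) and load the train and test splits as tf.data.Dataset objects.
train_ds, valid_ds = tfds.load('oxford_iiit_pet', split=["train", "test"])
```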
We will then build the training and validation (test) data loaders using the tf.data.Dataset API.
The parse_data function shown below scales the image to the range [0, 1] and resizes it. It also resizes the associated segmentation mask and one-hot encodes it. Since the ground-truth mask labels are [1, 2, 3], we subtract 1 from the mask to bring the labels to [0, 1, 2].
When resizing the mask, use "nearest" interpolation. The default interpolation method ("bilinear") blends neighboring label values and results in an erroneous mask.
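The toy example below (a sketch, not part of the original pipeline) shows the problem: upsampling a two-class mask with bilinear interpolation produces values between the original labels, whereas nearest-neighbor interpolation only copies labels that already exist. The 2×2 mask is invented for illustration.

```python
import numpy as np
import tensorflow as tf

# Toy 2x2 label mask of shape (H, W, 1) using two classes: 0 and 2.
mask = tf.constant([[0, 2], [2, 0]], dtype=tf.uint8)[..., tf.newaxis]

# Upsample to 4x4 with both interpolation methods.
nearest = tf.image.resize(mask, size=(4, 4), method="nearest")    # keeps the uint8 dtype
bilinear = tf.image.resize(mask, size=(4, 4), method="bilinear")  # returns float32

print(np.unique(nearest.numpy()))   # [0 2]: only labels that were already in the mask
print(np.unique(bilinear.numpy()))  # also contains blended, in-between values
```

In a real mask, those blended values are not valid class indices; casting them back to integers can create labels, such as 1 (the outline class), in places where that class never appeared.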
```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

# `configs` (defined elsewhere in the notebook) provides img_size, batch_size and num_classes (3).
def parse_data(example):
    # Parse image: convert uint8 to float32 in [0, 1], then resize.
    image = example["image"]
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, size=(configs.img_size, configs.img_size))

    # Parse mask: ground truth labels are [1, 2, 3], shift them to [0, 1, 2].
    mask = example["segmentation_mask"] - 1
    # "nearest" keeps every mask value a valid integer label.
    mask = tf.image.resize(mask, size=(configs.img_size, configs.img_size), method='nearest')
    # Drop the trailing channel axis and one-hot encode: (H, W, 1) -> (H, W, num_classes).
    mask = tf.one_hot(tf.squeeze(mask, axis=-1), depth=configs.num_classes)

    return image, mask

trainloader = (
    train_ds
    .shuffle(1024)
    .map(parse_data, num_parallel_calls=AUTOTUNE)
    .batch(configs.batch_size)
    .prefetch(AUTOTUNE)
)

validloader = (
    valid_ds
    .map(parse_data, num_parallel_calls=AUTOTUNE)
    .batch(configs.batch_size)
    .prefetch(AUTOTUNE)
)
```
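A quick way to sanity-check the label shift and one-hot encoding without downloading the dataset is to run the same operations on a hand-made mask. This is a minimal sketch; the 2×2 mask is invented for illustration, and the explicit int32 cast is an extra safety step that parse_data does not perform.

```python
import tensorflow as tf

NUM_CLASSES = 3  # pet, pet outline, background

# Toy (H, W, 1) mask that uses the dataset's raw labels 1, 2 and 3.
raw_mask = tf.constant([[1, 2], [3, 1]], dtype=tf.uint8)[..., tf.newaxis]

# Mirror parse_data: shift to 0-based labels, drop the channel axis, one-hot encode.
mask = tf.cast(raw_mask, tf.int32) - 1
one_hot = tf.one_hot(tf.squeeze(mask, axis=-1), depth=NUM_CLASSES)

print(one_hot.shape)                        # (2, 2, 3)
print(tf.argmax(one_hot, axis=-1).numpy())  # recovers the 0-based labels
```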
Let's interact with the dataset below:



The Model

The model used here is a vanilla UNET architecture. It consists of an encoder network and a decoder network. The input to this architecture is the image, while the output is the pixel-wise map. You can learn more about encoder-decoder (autoencoder) networks in the Towards Deep Generative Modeling with W&B report.
UNET-like architectures are also common in self-supervised deep learning tasks such as image inpainting.
You can learn more about the UNET architecture in this Line by Line Explanation.
Figure 3: A typical UNET architecture. (Source)

The code snippet shown below builds our model architecture for semantic segmentation.
```python
# ref: https://github.com/ayulockin/deepimageinpainting/blob/master/Image_Inpainting_Autoencoder_Decoder_v2_0.ipynb
from tensorflow.keras import layers, models


class SegmentationModel:
    '''
    Build UNET based model for segmentation task.
    '''
    # The default input size reads `configs.img_size`, defined earlier in the notebook.
    def prepare_model(self, OUTPUT_CHANNEL, input_size=(configs.img_size, configs.img_size, 3)):
        inputs = layers.Input(input_size)

        # Encoder: each block halves H and W; conv1..conv4 are kept for the skip connections.
        conv1, pool1 = self.__ConvBlock(32, (3,3), (2,2), 'relu', 'same', inputs)
        conv2, pool2 = self.__ConvBlock(64, (3,3), (2,2), 'relu', 'same', pool1)
        conv3, pool3 = self.__ConvBlock(128, (3,3), (2,2), 'relu', 'same', pool2)
        conv4, pool4 = self.__ConvBlock(256, (3,3), (2,2), 'relu', 'same', pool3)

        # Bottleneck + decoder: each block doubles H and W and concatenates the matching encoder output.
        conv5, up6 = self.__UpConvBlock(512, 256, (3,3), (2,2), (2,2), 'relu', 'same', pool4, conv4)
        conv6, up7 = self.__UpConvBlock(256, 128, (3,3), (2,2), (2,2), 'relu', 'same', up6, conv3)
        conv7, up8 = self.__UpConvBlock(128, 64, (3,3), (2,2), (2,2), 'relu', 'same', up7, conv2)
        conv8, up9 = self.__UpConvBlock(64, 32, (3,3), (2,2), (2,2), 'relu', 'same', up8, conv1)

        conv9 = self.__ConvBlock(32, (3,3), (2,2), 'relu', 'same', up9, False)
        # One output channel per class; softmax turns them into per-pixel class probabilities.
        outputs = layers.Conv2D(OUTPUT_CHANNEL, (3, 3), activation='softmax', padding='same')(conv9)

        return models.Model(inputs=[inputs], outputs=[outputs])

    def __ConvBlock(self, filters, kernel_size, pool_size, activation, padding, connecting_layer, pool_layer=True):
        # Two convolutions, optionally followed by max pooling.
        conv = layers.Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(connecting_layer)
        conv = layers.Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(conv)
        if pool_layer:
            pool = layers.MaxPooling2D(pool_size)(conv)
            return conv, pool
        else:
            return conv

    def __UpConvBlock(self, filters, up_filters, kernel_size, up_kernel, up_stride, activation, padding, connecting_layer, shared_layer):
        # Two convolutions, a transposed convolution that upsamples, then the skip connection.
        conv = layers.Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(connecting_layer)
        conv = layers.Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(conv)
        up = layers.Conv2DTranspose(filters=up_filters, kernel_size=up_kernel, strides=up_stride, padding=padding)(conv)
        up = layers.concatenate([up, shared_layer], axis=3)

        return conv, up
```
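Because the encoder applies four 2×2 max-pooling layers and each decoder block upsamples with a stride-2 transposed convolution, the concatenations in __UpConvBlock only line up when configs.img_size is a multiple of 16. The pure-Python helper below (a hypothetical check, not part of the original code) traces the spatial sizes to show why.

```python
def skip_connections_align(img_size: int, num_pools: int = 4) -> bool:
    """Return True if every decoder upsampling matches its encoder skip connection."""
    # Spatial size before each pooling layer (conv1..conv4), then the bottleneck size.
    sizes = [img_size]
    for _ in range(num_pools):
        sizes.append(sizes[-1] // 2)  # MaxPooling2D((2, 2)) with "valid" padding floors odd sizes
    skips, size = sizes[:-1], sizes[-1]

    # Walk back up the decoder, deepest skip connection first.
    for skip in reversed(skips):
        size *= 2  # Conv2DTranspose with strides=(2, 2) and padding="same" doubles the size
        if size != skip:
            return False  # layers.concatenate would fail on mismatched shapes
    return True


print(skip_connections_align(100))                                 # False
print([n for n in range(96, 161) if skip_connections_align(n)])  # [96, 112, 128, 144, 160]
```

For example, with img_size = 100 the third decoder block produces a 24×24 map while conv3 is 25×25, so the concatenation fails.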
Notice that OUTPUT_CHANNEL is 3 for our dataset because there are three classes of pixels, as described in the dataset section. We are effectively doing multi-class classification for every pixel: each pixel belongs to exactly one of the three classes.
Since this is a per-pixel multi-class classification problem, the output activation is softmax, which turns the three output channels at each pixel into class probabilities that sum to 1.
```python
OUTPUT_CHANNEL = 3  # pet, pet outline, background

model = SegmentationModel().prepare_model(OUTPUT_CHANNEL)
model.compile(optimizer="adam", loss="categorical_crossentropy")
```
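Finally, the model is compiled with the categorical_crossentropy loss, which expects one-hot encoded masks like the ones parse_data produces. If we don't one-hot encode the mask, we can use sparse_categorical_crossentropy with the integer masks instead.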
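To see why the two losses are interchangeable here, the sketch below computes both on a toy batch of random integer masks and random per-pixel softmax outputs (made-up data standing in for real masks and model predictions). categorical_crossentropy receives the one-hot masks, sparse_categorical_crossentropy receives the integer masks directly, and the two values agree.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)

# Toy batch: 2 masks of 4x4 pixels with integer labels in {0, 1, 2}.
int_masks = rng.integers(0, 3, size=(2, 4, 4))
one_hot_masks = tf.one_hot(int_masks, depth=3)  # (2, 4, 4, 3)

# Fake per-pixel softmax outputs, standing in for the model's predictions.
logits = rng.normal(size=(2, 4, 4, 3)).astype("float32")
probs = tf.nn.softmax(logits, axis=-1)

cce = tf.keras.losses.CategoricalCrossentropy()(one_hot_masks, probs)
scce = tf.keras.losses.SparseCategoricalCrossentropy()(int_masks, probs)

print(float(cce), float(scce))  # the two losses agree up to floating-point error
```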

The WandbSemanticLogger Callback - Interactive Visualization of Predictions

When working on semantic segmentation, you can interactively visualize your model's predictions in Weights & Biases. If you have images with masks for semantic segmentation, you can log the masks and toggle them on and off in the UI. Check out the official documentation here.
The report Image Masks for Semantic Segmentation by Stacey Svetlichnaya walks you through the interactive controls for this tool. It covers the various nuances of logging images and masks.
We will build the logger by inheriting from WandbEvalCallback, an abstract base class for building Keras callbacks primarily for visualizing model predictions and, secondarily, datasets. To build our own segmentation logger, we just need to implement the add_ground_truth and add_model_predictions methods, as shown below.
```python
import numpy as np
import tensorflow as tf
import wandb
from wandb.integration.keras import WandbEvalCallback


class WandbSemanticLogger(WandbEvalCallback):
    def __init__(
        self,
        validloader,
        data_table_columns=["index", "image"],
        pred_table_columns=["epoch", "index", "image", "prediction"],
        num_samples=100,
    ):
        super().__init__(
            data_table_columns,
            pred_table_columns,
        )
        # Unpack the `validloader` and take samples to log.
        self.val_data = validloader.unbatch().take(num_samples)

    def add_ground_truth(self, logs):
        # Iterate through the dataset and add the samples to `data_table`.
        for idx, (image, mask) in enumerate(self.val_data):
            self.data_table.add_data(
                idx,
                self._prepare_wandb_mask(
                    image.numpy(),
                    np.argmax(mask.numpy(), axis=-1),  # one-hot mask -> integer label map
                    "ground_truth"
                )
            )

    def add_model_predictions(self, epoch, logs):
        # Get a reference to the `data_table` logged in `on_train_begin`.
        data_table_ref = self.data_table_ref

        # Iterate through the dataset, get predictions and add them to `pred_table`.
        for idx, (image, mask) in enumerate(self.val_data):
            prediction = self.model.predict(tf.expand_dims(image, axis=0), verbose=0)
            prediction = np.argmax(prediction[0], axis=-1)  # (H, W, C) probabilities -> (H, W) labels

            self.pred_table.add_data(
                epoch,
                data_table_ref.data[idx][0],
                self._prepare_wandb_mask(
                    data_table_ref.data[idx][1],
                    np.argmax(mask.numpy(), axis=-1),
                    "ground_truth"
                ),
                self._prepare_wandb_mask(
                    data_table_ref.data[idx][1],
                    prediction,
                    "prediction"
                )
            )

    def _prepare_wandb_mask(self, image, mask, mask_type):
        # Key the mask by `mask_type` so ground truth and predictions are labeled separately.
        # `labels()` returns a {class_id: class_name} dict and is defined elsewhere in the notebook.
        return wandb.Image(
            image,
            masks={
                mask_type: {
                    "mask_data": mask,
                    "class_labels": labels()
                }
            }
        )
```
Under the hood, WandbEvalCallback logs the data_table (an instance of W&B Tables) to W&B when its on_train_begin method is invoked. Once the table is uploaded as a W&B Artifact, we get a reference to it through the data_table_ref attribute. Its data attribute is a 2D list that can be indexed like self.data_table_ref.data[idx][n], where idx is the row number and n is the column number. add_model_predictions above uses it to look up the index and image of each logged sample.
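The logger calls a labels() helper that the report does not show. A minimal sketch, assuming the class order listed in the dataset section after the -1 shift in parse_data (the class names here are illustrative choices, not taken from the original code), could look like this:

```python
import numpy as np


def labels():
    # Hypothetical helper: maps the 0-based class ids produced by parse_data
    # (raw labels 1, 2, 3 minus 1) to names shown in the W&B mask overlay.
    return {0: "pet", 1: "pet_outline", 2: "background"}


# Example: the `masks` argument that _prepare_wandb_mask builds for a toy 2x2 label map.
toy_mask = np.array([[0, 1], [2, 2]])
masks = {"ground_truth": {"mask_data": toy_mask, "class_labels": labels()}}

print(masks["ground_truth"]["class_labels"][int(toy_mask[0, 1])])  # pet_outline
```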

Results

Now on to the exciting part. I trained the model for 10 epochs. The loss and validation loss are shown in the chart below. Feel free to train the model for more epochs and play with other hyperparameters.


The result of WandbSemanticLogger is shown below. Click the ⚙️ icon in the media panel below (Result of SemanticLogger) to explore the interactive controls. You can visualize images and masks separately and choose which semantic class to visualize.

Observations

  • The model learns to predict the pet and background class well.
  • The model has a hard time segmenting the pet_outline class. This is because of the high class imbalance (outline pixels form a thin border around the pet and are far fewer than pet or background pixels), and the model is not regularized to counter this imbalance.
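One way to check the imbalance on your own data is to count pixels per class. The sketch below is not part of the original report, and its function names are hypothetical: it computes per-class pixel frequencies for integer masks (for example np.argmax(mask, axis=-1) of the one-hot masks from parse_data) and derives inverse-frequency weights, one common starting point for a class-weighted loss. The 4×4 mask is made up for illustration.

```python
import numpy as np


def class_pixel_frequencies(masks, num_classes=3):
    """Fraction of pixels per class over an iterable of integer label maps."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for m in masks:
        counts += np.bincount(np.asarray(m).ravel(), minlength=num_classes)
    return counts / counts.sum()


def inverse_frequency_weights(freqs):
    """Weights proportional to 1 / frequency, rescaled so they average to 1."""
    w = 1.0 / np.maximum(freqs, 1e-12)  # guard against classes that never appear
    return w * len(w) / w.sum()


# Toy mask: pet (0) and background (2) dominate; the outline (1) is a thin strip.
toy = np.array([
    [2, 2, 2, 2],
    [2, 0, 0, 1],
    [2, 0, 0, 1],
    [2, 0, 0, 2],
])
freqs = class_pixel_frequencies([toy])
print(freqs)                             # fractions for pet, pet_outline, background
print(inverse_frequency_weights(freqs))  # the rare outline class gets the largest weight
```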



Conclusion and Final Thoughts

I hope you enjoyed this report on semantic segmentation. The intention of this report was twofold:
  • Make the semantic segmentation technique more accessible to interested folks.
  • Show how Weights & Biases can help interactively visualize a model's predictions and metrics, and what observations one can derive from these visualizations.
I would love to get your feedback in the comment section. 😄
