Semantic Segmentation with UNets in PyTorch
Perform semantic segmentation of a given scene and record results in wandb tables.
Created on July 12|Last edited on July 12
The general idea of object detection is to locate given objects in an image and classify each detected object correctly. For example, given an image of vehicles, a detection model finds every vehicle and tells whether it is a car, a bus, a truck, etc.
So if we already have models for object detection, why do we need a new concept called semantic segmentation?
Object detection localizes objects with bounding boxes, whereas semantic segmentation labels each pixel of an image with the class of what is being represented. Because we predict a label for every pixel in the image, this task is commonly referred to as dense prediction.
There are a number of popular architectures for this task. Today we shall dive into the UNet architecture and implement it as well.
1. UNET ARCHITECTURE
    1.1 Paths in UNet
        1.1.1 Contracting Path
        1.1.2 Expansive Path
2. PYTORCH IMPLEMENTATION
    2.1 Preliminaries
    2.2 Configuration
    2.3 Transformations
    2.4 Dataset Class
    2.5 DataLoaders
    2.6 UNet Architecture
    2.7 Loss Function
    2.8 Engine
    2.9 Train
3. Visualize Predictions using Wandb Tables
4. CONCLUSION
5. COLAB NOTEBOOK
REFERENCES
Reach Out
1. UNET ARCHITECTURE

The above diagram has been taken from the original UNet paper (Ronneberger et al., 2015).
1.1 Paths in UNet
UNet consists of a contracting path and an expansive path.
1.1.1 Contracting Path
- The contracting path follows the typical architecture of a convolutional network.
- It consists of the repeated application of two 3x3 convolutions (unpadded convolutions), each followed by a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2 for downsampling.
- At each downsampling step we double the number of feature channels.
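As a rough illustration of the contracting-path bookkeeping, the pure-Python sketch below tracks the spatial size and channel count level by level. It assumes the 572×572 input tile and 64 starting channels shown in the paper's diagram; the helper names are hypothetical and are not part of the implementation later in this report.

```python
def conv3x3_valid(size):
    """An unpadded 3x3 convolution loses one pixel on each border."""
    return size - 2

def maxpool2x2(size):
    """2x2 max pooling with stride 2 halves the spatial size."""
    return size // 2

def contracting_path(input_size=572, first_channels=64, levels=5):
    """Return (size, channels) after the two 3x3 convs of each level."""
    size, channels = input_size, first_channels
    out = []
    for level in range(levels):
        size = conv3x3_valid(conv3x3_valid(size))
        out.append((size, channels))
        if level < levels - 1:       # no pooling after the lowest level
            size = maxpool2x2(size)
            channels *= 2            # channels double at each downsampling step
    return out

for size, channels in contracting_path():
    print(f"{size}x{size}, {channels} channels")
```

The lowest level ends at 28×28 with 1024 channels, which is where the expansive path starts.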
1.1.2 Expansive Path
- Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”) that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU.
- The cropping is necessary due to the loss of border pixels in every convolution.
- At the final layer a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes.
In total the network has 23 convolutional layers.
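Continuing the same arithmetic, the sketch below (again pure Python, with hypothetical helper names) follows the expansive path under the paper's settings: each up-convolution doubles the size and halves the channels, the matching contracting feature map is center-cropped before concatenation, and two unpadded 3x3 convolutions trim 4 more pixels. It also reproduces the count of 23 convolutional layers, assuming five levels.

```python
# (size, channels) after the two unpadded convs at each contracting level,
# assuming a 572x572 input tile as in the paper's diagram
CONTRACTING = [(568, 64), (280, 128), (136, 256), (64, 512), (28, 1024)]

def center_crop_offset(src_size, tgt_size):
    """Pixels removed from each border when center-cropping src_size to tgt_size."""
    return (src_size - tgt_size) // 2

def expansive_path(contracting=CONTRACTING):
    """Return (size_after_convs, concat_channels, crop_per_border) per level."""
    size, channels = contracting[-1]
    steps = []
    for skip_size, skip_channels in reversed(contracting[:-1]):
        size, channels = size * 2, channels // 2    # 2x2 up-conv: double size, halve channels
        crop = center_crop_offset(skip_size, size)  # crop the skip map to match
        concat_channels = channels + skip_channels  # concatenation along channels
        size -= 4                                   # two unpadded 3x3 convs (output: `channels`)
        steps.append((size, concat_channels, crop))
    return steps

def count_conv_layers(levels=5):
    contracting = 2 * levels        # two 3x3 convs per contracting level
    expansive = 3 * (levels - 1)    # one 2x2 up-conv + two 3x3 convs per expansive level
    final = 1                       # 1x1 conv to the class scores
    return contracting + expansive + final

for step in expansive_path():
    print(step)
print(count_conv_layers())  # 23
```

With these assumptions the output map is 388×388, smaller than the 572×572 input: this is the border loss that the cropping compensates for.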
In simple words, a UNet combines downsampling (the contracting path) with upsampling (the expansive path), linked by skip connections between matching levels. Drawn out, the two paths form a U, hence the name UNet.
2. PYTORCH IMPLEMENTATION
Let us implement a UNet for a semantic segmentation task and prepare a pipeline that can be reused for similar datasets. Every step is explained in detail to make it easy to follow.
2.1 Preliminaries
We begin by installing and importing the required packages and downloading the dataset.
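```python
!pip install -q wandb
!pip install -q torch_snippets pytorch_model_summary

import os
if not os.path.exists('dataset1'):
    # Download and extract the dataset only once
    !wget -q https://www.dropbox.com/s/0pigmmmynbf9xwq/dataset1.zip
    !unzip -q dataset1.zip
    !rm dataset1.zip

from torch_snippets import *   # provides helpers such as read() and stems()
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torchvision.models import vgg16_bn
from tqdm import tqdm
```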
Then we log in to our W&B account.
```python
# Wandb Login
import wandb
wandb.login()
```
2.2 Configuration
Next, we define the project configuration and the W&B config, and initialize a W&B run.
```python
class config:
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    LEARNING_RATE = 1e-3
    N_EPOCHS = 20

# wandb config
WANDB_CONFIG = {'_wandb_kernel': 'neuracort'}

# Initialize W&B
run = wandb.init(project='semantic_segmentation_unet', config=WANDB_CONFIG)
```
2.3 Transformations
To keep things simple, we only use basic transforms, defined in the function get_transforms().
```python
def get_transforms():
    return transforms.Compose([
        transforms.ToTensor(),                       # HWC array -> CHW tensor
        transforms.Normalize([0.485, 0.456, 0.406],  # ImageNet channel means
                             [0.229, 0.224, 0.225])  # ImageNet channel stds
    ])
```
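We simply convert the input image to a tensor and normalize it with the ImageNet channel means and standard deviations, the statistics expected by torchvision's ImageNet-pretrained models.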
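To make the normalization concrete, the pure-Python sketch below applies the same arithmetic to a single RGB pixel. One subtlety: torchvision's ToTensor only rescales to [0, 1] for PIL images or uint8 arrays, so for the float array produced by `im / 255.` in the dataset's collate_fn it leaves the values unchanged and the scaling comes from that division. The function name normalize_pixel is hypothetical.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Scale 0-255 values to [0, 1], then standardize each channel."""
    scaled = [v / 255.0 for v in rgb]  # the `im / 255.` step
    return [(x - m) / s for x, m, s in zip(scaled, mean, std)]  # the Normalize step

print(normalize_pixel((255, 128, 0)))
```

A pixel equal to the ImageNet mean maps to roughly zero in every channel, while saturated channels land a couple of standard deviations away.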
2.4 Dataset Class
As in any regular PyTorch project, we define a Dataset class.
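```python
class SegmentationData(Dataset):
    def __init__(self, split):
        # File names (without extension) of all images in this split
        self.items = stems(f'dataset1/images_prepped_{split}')
        self.split = split

    def __len__(self):
        return len(self.items)

    def __getitem__(self, ix):
        image = read(f'dataset1/images_prepped_{self.split}/{self.items[ix]}.png', 1)
        image = cv2.resize(image, (224, 224))
        mask = read(f'dataset1/annotations_prepped_{self.split}/{self.items[ix]}.png')
        # Nearest-neighbour interpolation keeps every mask value a valid class id
        mask = cv2.resize(mask, (224, 224), interpolation=cv2.INTER_NEAREST)
        return image, mask

    def choose(self):
        return self[randint(len(self))]

    def collate_fn(self, batch):
        ims, masks = list(zip(*batch))
        # Scale to [0, 1], apply ToTensor + Normalize, add a batch dim and concatenate
        ims = torch.cat([get_transforms()(im.copy() / 255.)[None] for im in ims]).float().to(config.DEVICE)
        # Masks become an (N, H, W) tensor of integer class ids
        ce_masks = torch.cat([torch.Tensor(mask[None]) for mask in masks]).long().to(config.DEVICE)
        return ims, ce_masks
```
- __init__ collects the file names (stems) of the images in the given split.
- __len__ returns the number of images in the split.
- __getitem__ loads an image and its mask and resizes both to 224×224; the mask uses nearest-neighbour interpolation so that it contains only valid class ids.
- choose returns a random (image, mask) pair, which is handy for debugging.
- collate_fn preprocesses a batch: it normalizes the images, converts the masks to integer tensors and moves both to config.DEVICE.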
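cv2.resize defaults to bilinear interpolation (cv2.INTER_LINEAR), which suits images but not label masks, whose values are class ids rather than intensities. The pure-Python sketch below uses a simplified 1-D stand-in for resizing (not OpenCV's exact sampling grid; the function names are hypothetical) to show linear interpolation inventing a class that never appears at that boundary, while nearest-neighbour only copies existing labels.

```python
def resize_1d_linear(row, new_len):
    """Linear interpolation with aligned end points; assumes new_len >= 2."""
    old_len = len(row)
    out = []
    for i in range(new_len):
        pos = i * (old_len - 1) / (new_len - 1)  # position in the source row
        lo = int(pos)
        hi = min(lo + 1, old_len - 1)
        frac = pos - lo
        out.append(round(row[lo] * (1 - frac) + row[hi] * frac))
    return out

def resize_1d_nearest(row, new_len):
    """Nearest-neighbour: every output value is copied from the input."""
    old_len = len(row)
    return [row[min(int(i * old_len / new_len), old_len - 1)] for i in range(new_len)]

mask_row = [0, 0, 2, 2]  # class 0 next to class 2
print(resize_1d_linear(mask_row, 7))   # contains a spurious class 1
print(resize_1d_nearest(mask_row, 7))  # only classes 0 and 2
```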
2.5 DataLoaders
Here we create a function to initialise the PyTorch dataset and create the dataloaders.
```python
def get_dataloaders():
    trn_ds = SegmentationData('train')
    val_ds = SegmentationData('test')
    trn_dl = DataLoader(trn_ds, batch_size=4, shuffle=True, collate_fn=trn_ds.collate_fn)
    val_dl = DataLoader(val_ds, batch_size=1, shuffle=True, collate_fn=val_ds.collate_fn)
    return trn_dl, val_dl

trn_dl, val_dl = get_dataloaders()
```
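We first initialise the training and validation datasets (the 'test' split serves as the validation set) and then wrap them in dataloaders, with a batch size of 4 for training and 1 for validation.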
2.6 UNet Architecture
First we define the convolution block.
```python
def conv(in_channels, out_channels):
    # 3x3 convolution with padding=1 (keeps H and W), then BatchNorm and ReLU
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )
```
Next, we define the up-convolution block.
```python
def up_conv(in_channels, out_channels):
    # 2x2 transposed convolution with stride 2 doubles H and W
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        nn.ReLU(inplace=True)
    )
```
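These two blocks rely on the output-size rules of PyTorch's convolution layers. The sketch below writes out the formulas from the PyTorch documentation for Conv2d and ConvTranspose2d (one spatial dimension, with dilation and output_padding at their defaults in the examples) as plain Python functions; the function names are illustrative, not library APIs.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output size of nn.Conv2d (also valid for nn.MaxPool2d)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def conv_transpose2d_out(size, kernel_size, stride=1, padding=0,
                         output_padding=0, dilation=1):
    """Output size of nn.ConvTranspose2d."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

print(conv2d_out(224, 3, padding=1))         # conv block keeps the size
print(conv_transpose2d_out(7, 2, stride=2))  # up_conv block doubles it
print(conv2d_out(572, 3))                    # unpadded conv, as in the paper
```

Because conv uses padding=1, it keeps the spatial size, so unlike the original paper the skip connections need no cropping; up_conv exactly doubles the size, undoing one 2x2 max-pool.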
Now that we have the building blocks, we leverage them to create the entire architecture.
```python
class UNet(nn.Module):
    def __init__(self, pretrained=True, out_channels=12):
        super().__init__()
        # Encoder: the convolutional feature extractor of VGG16 with BatchNorm
        # (recent torchvision versions deprecate pretrained= in favour of weights=)
        self.encoder = vgg16_bn(pretrained=pretrained).features
        # Shapes in the comments below assume a 224x224 input
        self.block1 = nn.Sequential(*self.encoder[:6])       #  64 ch, 224x224
        self.block2 = nn.Sequential(*self.encoder[6:13])     # 128 ch, 112x112
        self.block3 = nn.Sequential(*self.encoder[13:20])    # 256 ch, 56x56
        self.block4 = nn.Sequential(*self.encoder[20:27])    # 512 ch, 28x28
        self.block5 = nn.Sequential(*self.encoder[27:34])    # 512 ch, 14x14
        self.bottleneck = nn.Sequential(*self.encoder[34:])  # 512 ch, 7x7
        self.conv_bottleneck = conv(512, 1024)

        # Decoder: upsample, then convolve the (upsampled + skip) channels
        self.up_conv6 = up_conv(1024, 512)
        self.conv6 = conv(512 + 512, 512)
        self.up_conv7 = up_conv(512, 256)
        self.conv7 = conv(256 + 512, 256)
        self.up_conv8 = up_conv(256, 128)
        self.conv8 = conv(128 + 256, 128)
        self.up_conv9 = up_conv(128, 64)
        self.conv9 = conv(64 + 128, 64)
        self.up_conv10 = up_conv(64, 32)
        self.conv10 = conv(32 + 64, 32)
        # 1x1 convolution maps each 32-channel pixel vector to class scores
        self.conv11 = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        # Contracting path (encoder)
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        bottleneck = self.bottleneck(block5)

        # Expansive path (decoder); torch.cat adds the skip connections
        x = self.conv_bottleneck(bottleneck)
        x = self.up_conv6(x)
        x = torch.cat([x, block5], dim=1)
        x = self.conv6(x)
        x = self.up_conv7(x)
        x = torch.cat([x, block4], dim=1)
        x = self.conv7(x)
        x = self.up_conv8(x)
        x = torch.cat([x, block3], dim=1)
        x = self.conv8(x)
        x = self.up_conv9(x)
        x = torch.cat([x, block2], dim=1)
        x = self.conv9(x)
        x = self.up_conv10(x)
        x = torch.cat([x, block1], dim=1)
        x = self.conv10(x)
        x = self.conv11(x)  # (N, out_channels, H, W) logits
        return x
```
You can refer to the architecture diagram to understand how the above layers are interconnected.
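Did you notice the use of torch.cat? Can you relate it to the UNet architecture? Each torch.cat joins an upsampled decoder feature map with the encoder feature map of the same spatial size along the channel dimension (dim=1); these are the skip connections drawn as "copy and crop" arrows in the diagram.
Here we use the convolutional part of VGG16 with batch normalization (vgg16_bn), pretrained on ImageNet, as the encoder, and slice it into five blocks plus a bottleneck. The decoder is built from the conv and up_conv blocks defined above. Because these convolutions use padding=1, the skip feature maps already match the upsampled maps in size for the 224×224 inputs used here, so no cropping is needed, unlike in the original paper.
The forward function is best understood alongside the architecture diagram, which shows how the blocks are interconnected.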
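To check that the slice boundaries and the decoder's in_channels fit together, the pure-Python sketch below replays the layer list of torchvision's vgg16_bn().features on a 224×224 input. It assumes the standard VGG16 configuration (13 Conv2d(3x3, padding=1) + BatchNorm2d + ReLU triples and 5 MaxPool2d(2, 2) layers, 44 modules in total); VGG16_CFG and the helper functions are hypothetical names, not torchvision APIs.

```python
# Assumed standard VGG16 config: numbers are conv output channels, "M" a 2x2 max-pool
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]

def vgg16_bn_layers(cfg=VGG16_CFG):
    """Expand the config into (kind, out_channels) entries, one per module."""
    layers = []
    for v in cfg:
        if v == "M":
            layers.append(("pool", None))
        else:
            layers += [("conv", v), ("bn", v), ("relu", v)]
    return layers

def trace(layers, channels, size):
    """Padded 3x3 convs change only the channels; pools halve the size."""
    for kind, out_channels in layers:
        if kind == "conv":
            channels = out_channels
        elif kind == "pool":
            size //= 2
    return channels, size

def encoder_shapes(size=224):
    """(channels, size) after block1..block5 and the bottleneck, as sliced in UNet."""
    layers = vgg16_bn_layers()
    cuts = [(0, 6), (6, 13), (13, 20), (20, 27), (27, 34), (34, len(layers))]
    shapes, channels = [], 3
    for start, end in cuts:
        channels, size = trace(layers[start:end], channels, size)
        shapes.append((channels, size))
    return shapes

def decoder_in_channels(shapes, up_out=(512, 256, 128, 64, 32)):
    """in_channels each decoder conv needs after concatenating its skip tensor."""
    *skips, (_, size) = shapes
    needed = []
    for (skip_channels, skip_size), out in zip(reversed(skips), up_out):
        size *= 2  # ConvTranspose2d(kernel_size=2, stride=2)
        if size != skip_size:
            raise ValueError(f"size mismatch: {size} vs {skip_size}")
        needed.append(out + skip_channels)
    return needed

print(encoder_shapes())  # [(64, 224), (128, 112), (256, 56), (512, 28), (512, 14), (512, 7)]
print(decoder_in_channels(encoder_shapes()))  # [1024, 768, 384, 192, 96]
```

The slices do not follow VGG's own stage boundaries (block4, for example, has a max-pool in its middle), yet every skip tensor still matches its upsampled partner in size, and the concatenated channel counts equal the in_channels of conv6 to conv10.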
2.7 Loss Function
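We use cross-entropy loss because this is a multi-class problem (the model predicts 12 classes by default). nn.CrossEntropyLoss works directly on segmentation outputs: it takes logits of shape (N, C, H, W) and integer targets of shape (N, H, W) and averages the loss over all pixels. We define a small custom function that returns both the loss and the pixel accuracy.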
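```python
ce = nn.CrossEntropyLoss()

def UnetLoss(preds, targets):
    # preds: (N, C, H, W) logits; targets: (N, H, W) class ids
    ce_loss = ce(preds, targets)
    # Pixel accuracy: fraction of pixels whose arg-max class matches the target
    acc = (torch.max(preds, 1)[1] == targets).float().mean()
    return ce_loss, acc
```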
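For intuition, the pure-Python sketch below computes the two quantities UnetLoss returns, with the pixels flattened into a list: the mean softmax cross-entropy (what nn.CrossEntropyLoss computes with its default reduction='mean' and no class weights or ignore_index) and the pixel accuracy. The function names are hypothetical.

```python
import math

def pixel_cross_entropy(logits, targets):
    """Mean softmax cross-entropy over pixels.

    logits: one list of class scores per pixel; targets: one class id per pixel.
    """
    total = 0.0
    for scores, t in zip(logits, targets):
        m = max(scores)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[t]  # -log softmax(scores)[t]
    return total / len(targets)

def pixel_accuracy(logits, targets):
    """Fraction of pixels whose arg-max class equals the target."""
    preds = [max(range(len(s)), key=s.__getitem__) for s in logits]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

# Two pixels, three classes: the first is confidently right, the second undecided
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
targets = [0, 1]
print(pixel_cross_entropy(logits, targets), pixel_accuracy(logits, targets))
```

A pixel with uniform scores over C classes contributes log(C) to the loss, which is a useful sanity check for the loss at the start of training.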
2.8 Engine
The engine consists of the training and validation step functions. They follow the standard PyTorch pattern: set the model mode, run the forward pass, compute the loss and, during training, backpropagate and update the weights.
```python
class engine():
    @staticmethod
    def train_batch(model, data, optimizer, criterion):
        model.train()
        ims, ce_masks = data
        _masks = model(ims)
        optimizer.zero_grad()
        loss, acc = criterion(_masks, ce_masks)
        loss.backward()
        optimizer.step()
        return loss.item(), acc.item()

    @staticmethod
    @torch.no_grad()  # no gradients are needed for validation
    def validate_batch(model, data, criterion):
        model.eval()
        ims, masks = data
        _masks = model(ims)
        loss, acc = criterion(_masks, masks)
        return loss.item(), acc.item()
```
We also define a function make_model() which provides the model, criterion and the optimizer.
```python
def make_model():
    model = UNet().to(config.DEVICE)
    criterion = UnetLoss
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    return model, criterion, optimizer

model, criterion, optimizer = make_model()
```
2.9 Train
In each epoch we first iterate over the training dataloader, recording the loss and accuracy, and then do the same over the validation dataloader.
The metrics are logged to W&B with wandb.log().
```python
def run():
    for epoch in range(config.N_EPOCHS):
        print("####################")
        print(f"       Epoch: {epoch}   ")
        print("####################")

        # Collect (loss, acc) for every batch so that epoch averages can be logged
        train_metrics = [engine.train_batch(model, data, optimizer, criterion)
                         for data in tqdm(trn_dl, total=len(trn_dl))]
        val_metrics = [engine.validate_batch(model, data, criterion)
                       for data in tqdm(val_dl, total=len(val_dl))]
        train_loss, train_acc = [sum(m) / len(m) for m in zip(*train_metrics)]
        val_loss, val_acc = [sum(m) / len(m) for m in zip(*val_metrics)]

        wandb.log({'epoch': epoch,
                   'train_loss': train_loss, 'train_acc': train_acc,
                   'val_loss': val_loss, 'val_acc': val_acc})
        print()
```
Finally, we run the training!
```python
run()
```
Because the dataset contains only a small number of images, the model trains in a couple of minutes.
As the charts show, the accuracies increase steadily while the losses decrease. The best validation (pixel) accuracy is 91.52%, which is great considering how small the dataset is.
3. Visualize Predictions using Wandb Tables
W&B Tables let you log model predictions next to their inputs and ground truth, so you can browse and compare them interactively. First we create a table with wandb.Table() and specify the column names.
Then we iterate over the validation dataset and add one row per image with table.add_data(). Because the cells are images, each one is wrapped in wandb.Image().
Once all the rows have been added, we log the table to W&B with wandb.log().
```python
def save_panel(array, path):
    """Render a 2-D array with matplotlib and save it to `path`."""
    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.imshow(array)
    plt.savefig(path)
    plt.close()

def load_rgb(path):
    """Read a saved image back with OpenCV and convert BGR to RGB."""
    return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)

def save_table(table_name):
    table = wandb.Table(columns=['Original Image', 'Original Mask', 'Predicted Mask'],
                        allow_mixed_types=True)
    model.eval()
    with torch.no_grad():  # inference only, so no gradients are needed
        for bx, data in tqdm(enumerate(val_dl), total=len(val_dl)):
            im, mask = data
            _mask = model(im)
            _, _mask = torch.max(_mask, dim=1)  # predicted class id per pixel

            save_panel(im[0, 0].cpu().numpy(), "original_image.jpg")  # first channel of the normalized image
            save_panel(mask[0].cpu().numpy(), "original_mask.jpg")
            save_panel(_mask[0].cpu().numpy(), "predicted_mask.jpg")

            table.add_data(wandb.Image(load_rgb("original_image.jpg")),
                           wandb.Image(load_rgb("original_mask.jpg")),
                           wandb.Image(load_rgb("predicted_mask.jpg")))
    wandb.log({table_name: table})

save_table("Predictions")
```
You can see how the results have been neatly logged to the table. The original image, original mask as well as the predicted masks are available and using wandb tables you can easily visualize the results and keep a permanent record of them!
4. CONCLUSION
Thus we briefly learnt about the UNet architecture and built a semantic segmentation model with good validation accuracy.
We also learnt how to log metrics and results to wandb tables. Wandb tables serve as an interactive tool to visualize the predictions.
5. COLAB NOTEBOOK
To try things first-hand, you can use this Colab notebook: change the parameters and observe what happens. Maybe you can even create a report and share your findings with me!
REFERENCES
Reach Out
You can reach out to me on LinkedIn or Twitter.