
Multi Task Learning in PyTorch

Implement an Age and Gender Classifier using a single model
Created on February 10|Last edited on February 19
Many deep learning projects focus on accomplishing a single task with a single model. An animal classifier, for example, has one job: classifying animals correctly.
But in many cases we have several related tasks that all need predictions.
Take the example of an autonomous car. Given the live input feed, the car needs to perform several tasks such as lane detection, signboard detection, obstacle detection and so on. You have two choices here: create a separate training pipeline for each task, or combine them in a single pipeline, since the tasks are similar and share the same input (i.e., the live feed from the car's sensors).
This is where Multi-Task Learning comes into the picture.
Multi-Task Learning is a branch of research in which one or a few inputs are used to predict several different but ultimately connected outputs.

In this article, I will demonstrate how you can perform Age and Gender Classification on the same dataset in a single training pipeline.
We will use the FairFace dataset, which contains face images along with labels such as age and gender.
Let's start directly with the code!

Age and Gender Classification using Deep Learning

We begin by importing the required libraries.

Preliminaries

# Necessities
import os
import cv2
import glob
import time
import random
import pandas as pd
import numpy as np

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms

# Data Visualization
import matplotlib.pyplot as plt
%matplotlib inline

# Progress Bar
from tqdm.notebook import tqdm

# File download
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# wandb
!pip install wandb
import wandb
wandb.login()



Global Configuration

Next we define the Global Configuration which incorporates the file paths and hyperparameters.
class config:
    DIRECTORY_PATH = '/content'
    TRAIN_FILE_PATH = os.path.join(DIRECTORY_PATH, 'fairface-label-train.csv')
    VAL_FILE_PATH = os.path.join(DIRECTORY_PATH, 'fairface-label-val.csv')
    # No leading slash here: os.path.join would otherwise discard DIRECTORY_PATH
    MODEL_PATH = os.path.join(DIRECTORY_PATH, 'models/model.bin')

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    SEED = 42
    NUM_EPOCHS = 5
    IMAGE_SIZE = 224
    BATCH_SIZE = 32

    LEARNING_RATE = 1e-4
    PRETRAINED = True
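A short note on building paths with os.path.join: a component that starts with a slash is treated as absolute, so everything before it is dropped. The sketch below illustrates this with posixpath (the POSIX implementation that backs os.path on Linux and Colab) so that it behaves the same on any operating system. The variable names are only for illustration.

```python
import posixpath

base = "/content"

# A component with a leading slash is absolute, so `base` is discarded.
absolute = posixpath.join(base, "/models/model.bin")

# A relative component is appended to `base` as expected.
relative = posixpath.join(base, "models/model.bin")

print(absolute)  # /models/model.bin
print(relative)  # /content/models/model.bin
```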
To make the results reproducible, we seed every source of randomness using the function below.
def seed_everything(SEED = config.SEED):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    os.environ['PYTHONHASHSEED'] = str(SEED)

# Apply the seed before creating any data loaders or models
seed_everything()
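A quick way to convince yourself that seeding works is to reseed and check that the same draws come back. The sketch below uses only Python's built-in random module, and the helper name draw is hypothetical. The same idea applies to the NumPy and PyTorch generators seeded above.

```python
import random

def draw(seed, n=5):
    """Seed Python's RNG, then draw n pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]

first = draw(42)
second = draw(42)   # same seed as `first`
other = draw(7)     # different seed

print(first == second)  # the same seed reproduces the same sequence
print(first == other)   # a different seed gives a different sequence
```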



Download Dataset

Then, we download the dataset from Google Drive using the authentication method shown below.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

def getFile_from_drive( file_id, name ):
    downloaded = drive.CreateFile({'id': file_id})
    downloaded.GetContentFile(name)

getFile_from_drive('1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86', 'fairface-img-margin025-trainval.zip')
getFile_from_drive('1k5vvyREmHDW5TSM9QgB04Bvc8C8_7dl-', 'fairface-label-train.csv')
getFile_from_drive('1_rtz1M1zhvS0d5vVoXUamnohB6cJ02iJ', 'fairface-label-val.csv')

!unzip -qq fairface-img-margin025-trainval.zip



Load Dataset

The dataset has been downloaded and now we can create the training and validation dataframes.
train_df = pd.read_csv(config.TRAIN_FILE_PATH)
val_df = pd.read_csv(config.VAL_FILE_PATH)

train_df.head()



From the above dataset we will be using only the file, age and gender features for training and validation.
Note: When working with continuous features, scaling becomes important. Here, scaling is done by dividing a feature by its maximum value. Thus, to scale the images we divide them by 255 (the maximum pixel value), and to scale the age feature we divide it by 79 (the maximum age value).
# Maximum Value of Age
max(train_df.age) # prints 79
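As a small worked example of this scaling, the sketch below maps ages and pixel values into the range 0 to 1 and maps a scaled age back to years. The helper names scale_age and unscale_age are hypothetical and not part of the original code. The constant 79 is the maximum age reported above.

```python
MAX_AGE = 79      # maximum age in the training labels, as reported above
MAX_PIXEL = 255   # maximum value of an 8-bit pixel

def scale_age(age):
    """Map an age in years to the range [0, 1]."""
    return age / MAX_AGE

def unscale_age(scaled):
    """Map a scaled age back to years."""
    return scaled * MAX_AGE

pixels = [0, 51, 255]
scaled_pixels = [p / MAX_PIXEL for p in pixels]

print(scaled_pixels)                       # [0.0, 0.2, 1.0]
print(scale_age(79))                       # 1.0
print(round(unscale_age(scale_age(30))))   # 30
```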



Dataset Class

We now create the PyTorch Dataset class. A dataset class generally implements three important functions:
  1. __init__() - To initialize the Dataset module
  2. __len__() - To provide the number of samples
  3. __getitem__() - To obtain one sample at a time
In this case, we also preprocess the images and collate (i.e., combine) the samples inside the dataset class itself. To achieve this, we add two more functions:
  4. preprocess_image() - To resize, convert to tensors, permute, scale and normalize the images
  5. collate_fn() - To combine the preprocessed samples into a batch, scale the age labels and convert everything to tensors
In the preprocess_image() function, we perform the following preprocessing:
  • Resize all images to 224 x 224
  • Convert images to PyTorch Tensors
  • Permute the images into the form PyTorch expects: channels x height x width
  • Scale the images by dividing them by the highest pixel value of 255
  • Normalize the images with the ImageNet mean and standard deviation, as we are using a pretrained model
class GenderAgeDataset(Dataset):
    def __init__(self, df, transform = None):
        self.df = df
        self.normalize = transforms.Normalize(
            mean = [0.485, 0.456, 0.406],
            std = [0.229, 0.224, 0.225]
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        data = self.df.iloc[index].squeeze()
        file = data.file
        age = data.age
        gender = data.gender == "Female"
        img = cv2.imread(file)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img, age, gender

    def preprocess_image(self, img):
        img = cv2.resize(img, (config.IMAGE_SIZE, config.IMAGE_SIZE))
        img = torch.tensor(img).permute(2,0,1)
        img = self.normalize(img/255.)
        return img[None]

    def collate_fn(self, batch):
        imgs, ages, genders = [], [], []

        for img, age, gender in batch:
            img = self.preprocess_image(img)
            imgs.append(img)

            age = float((int(age)/79)) # Scaling the age with maximum value
            ages.append(age)

            gender = float(gender)
            genders.append(gender)

        ages, genders = [torch.tensor(x).to(config.DEVICE).float() for x in [ages, genders]]

        imgs = torch.cat(imgs).to(config.DEVICE)

        return imgs, ages, genders
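To see what a custom collate_fn receives and returns, here is a minimal sketch on a toy dataset. ToyDataset is a hypothetical stand-in for illustration only: it returns plain numbers instead of images, but the batching logic mirrors the class above. The DataLoader passes collate_fn a list of whatever __getitem__ returned.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical dataset that returns a raw (value, label) pair per sample."""

    def __init__(self, n):
        self.values = list(range(n))

    def __len__(self):
        return len(self.values)

    def __getitem__(self, index):
        value = self.values[index]
        return value, value % 2 == 0   # label is True for even values

    def collate_fn(self, batch):
        # `batch` is a list of (value, label) tuples, one per sample
        values, labels = zip(*batch)
        values = torch.tensor(values).float() / 10   # scale, like age / 79
        labels = torch.tensor(labels).float()        # bool -> 0.0 / 1.0
        return values, labels

toy = ToyDataset(10)
loader = DataLoader(toy, batch_size=4, shuffle=False, collate_fn=toy.collate_fn)

values, labels = next(iter(loader))
print(values.shape, labels.shape)  # torch.Size([4]) torch.Size([4])
```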



Create Datasets and DataLoaders

We define a function get_data() which defines the datasets and provides the training and validation dataloaders.
def get_data():
    """
    Function to obtain the Training and Validation DataLoaders

    Returns:
        train_loader: Training DataLoader
        val_loader: Validation DataLoader

    """

    train = GenderAgeDataset(train_df)
    val = GenderAgeDataset(val_df)

    train_loader = DataLoader(
        train,
        batch_size = config.BATCH_SIZE,
        shuffle = True,
        drop_last = True,
        collate_fn = train.collate_fn
    )

    val_loader = DataLoader(
        val,
        batch_size = config.BATCH_SIZE,
        shuffle = False,
        collate_fn = val.collate_fn
    )

    return train_loader, val_loader

train_loader, val_loader = get_data()

# Test the DataLoader
img, age, gender = next(iter(train_loader))
print(img.shape, age.shape, gender.shape)

# prints torch.Size([32, 3, 224, 224]) torch.Size([32]) torch.Size([32])
In the above output:
  • torch.Size([32, 3, 224, 224]) is the image batch in the form batch_size x channels x height x width
  • The two torch.Size([32]) are the age and gender labels, one value per image in the batch.



Model

To implement this task, we will use the pretrained version of the VGG16 model.
model = models.vgg16(pretrained=True)
model
VGG16 Model Architecture
Note that the VGG16 architecture has three key sub-modules: features, avgpool and classifier. For our work, we will freeze the features sub-module and override the avgpool and classifier sub-modules with custom layers.
Freezing means that the parameters (weights and biases) of the frozen layers will not be updated, because we want to preserve what the model learned during pre-training. This is achieved by setting param.requires_grad = False.
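The effect of freezing can be checked by counting trainable and frozen parameters. The sketch below uses a tiny hypothetical backbone and head instead of VGG16, so it runs without downloading any weights. The layer sizes are arbitrary.

```python
import torch.nn as nn

# Hypothetical stand-ins for a pretrained backbone and a new task head.
backbone = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
head = nn.Linear(4, 1)

# Freeze the backbone: its parameters will receive no gradient updates.
for param in backbone.parameters():
    param.requires_grad = False

model = nn.Sequential(backbone, head)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)

print(trainable, frozen)  # 5 36
```

Layers created after the freezing loop, like the head here, keep requires_grad = True by default, which is why only the new layers are trained.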
To build the model, we consider the following points:
  • All the layers under the features sub-module remain the same.
  • We override the avgpool sub-module with a custom block: a convolutional layer followed by max pooling, ReLU and flatten.
  • We override the classifier sub-module by creating custom layers in the GenderAgeClassifier class.
  • In the classifier, we create two separate heads (age_classifier and gender_classifier) branching out from the shared intermediate layers.
  • Note that we use two different loss functions because the tasks differ. For age_classifier we use L1Loss() as age is a continuous value, and for gender_classifier we use BCELoss() as gender is a binary value.
  • We sum the age estimation loss and the gender classification loss, weighting both equally.
  • We minimize the overall loss through backpropagation, which updates the trainable weights.
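The combination of the two losses can be sketched on a few made-up predictions. The weights GENDER_WEIGHT and AGE_WEIGHT are hypothetical names introduced here for illustration. Setting both to 1.0 gives a plain sum, and other values would let one task dominate the other.

```python
import torch
import torch.nn as nn

gender_criterion = nn.BCELoss()   # binary target
age_criterion = nn.L1Loss()       # continuous target (scaled age)

# Hypothetical task weights; 1.0 each is a plain sum of the two losses.
GENDER_WEIGHT, AGE_WEIGHT = 1.0, 1.0

# Made-up predictions and targets for a batch of two samples
pred_gender = torch.tensor([0.9, 0.2])
gender = torch.tensor([1.0, 0.0])
pred_age = torch.tensor([0.30, 0.50])
age = torch.tensor([0.25, 0.60])

gender_loss = gender_criterion(pred_gender, gender)
age_loss = age_criterion(pred_age, age)   # mean of |0.05| and |0.10|
total_loss = GENDER_WEIGHT * gender_loss + AGE_WEIGHT * age_loss

print(gender_loss.item(), age_loss.item(), total_loss.item())
```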
def get_model():
    """
    Function to create the model with custom classifier, define loss functions and optimizer

    Returns:
        model: Pretrained Model with custom classifier
        loss_functions: The two loss functions for the tasks
        optimizer: Optimizer for the Neural Network
    """

    # Load Model
    model = models.vgg16(pretrained=config.PRETRAINED)

    # Freeze Model
    for param in model.parameters():
        param.requires_grad = False

    # Overwrite avgpool layer
    model.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten()
    )

    # Model Custom Classifier
    class GenderAgeClassifier(nn.Module):
        def __init__(self):
            super(GenderAgeClassifier, self).__init__()

            self.intermediate = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.4),
                nn.Linear(512, 128),
                nn.ReLU(),
                nn.Dropout(0.4),
                nn.Linear(128, 64),
                nn.ReLU()
            )

            self.age_classifier = nn.Sequential(
                nn.Linear(64, 1),
                nn.Sigmoid()
            )

            self.gender_classifier = nn.Sequential(
                nn.Linear(64, 1),
                nn.Sigmoid()
            )

        def forward(self, x):
            x = self.intermediate(x)
            age = self.age_classifier(x)
            gender = self.gender_classifier(x)

            return gender, age

    # Override model classifier
    model.classifier = GenderAgeClassifier()

    # Loss Functions
    gender_criterion = nn.BCELoss()
    age_criterion = nn.L1Loss()
    loss_functions = gender_criterion, age_criterion

    # Optimizer
    optimizer = optim.Adam(
        model.parameters(),
        lr = config.LEARNING_RATE
    )
    return model.to(config.DEVICE), loss_functions, optimizer
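The input size of 2048 for the first linear layer follows from the spatial sizes in the network. The arithmetic can be traced with the standard output-size formula for convolution and pooling layers. The helper conv_out is a hypothetical name used only in this sketch, and the input is assumed to be 224 x 224 as configured above.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

size = 224
# VGG16's `features` block has five 2x2 max-pool stages; its 3x3 convolutions
# use padding 1 and therefore keep the spatial size unchanged.
for _ in range(5):
    size = conv_out(size, kernel=2, stride=2)   # 224 -> 112 -> 56 -> 28 -> 14 -> 7

size = conv_out(size, kernel=3)                 # Conv2d(512, 512, kernel_size=3): 7 -> 5
size = conv_out(size, kernel=2, stride=2)       # MaxPool2d(2): 5 -> 2

flattened = 512 * size * size                   # channels x height x width
print(flattened)  # 2048
```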



Engine

In the Engine, we define the functions to train and evaluate a batch of data, following the points listed above.
# Training Function
def train_batch(data, model, criterion, optimizer):
    """
    Function to train one batch of Training Data

    Args:
        data: Data for model training
        model: Deep Learning Model
        criterion: Loss Functions
        optimizer: Optimizer
    Returns:
        total_loss: Total loss value for the batch
    """
    model.train()

    img, age, gender = data # Get Data
    optimizer.zero_grad() # Zero the gradients

    pred_gender, pred_age = model(img) # Obtain predictions

    # Calculate Loss (drop the trailing dimension to match the [batch] targets)
    gender_criterion, age_criterion = criterion
    gender_loss = gender_criterion(pred_gender.squeeze(-1), gender)
    age_loss = age_criterion(pred_age.squeeze(-1), age)

    total_loss = gender_loss + age_loss
    total_loss.backward()

    optimizer.step() # Update Parameters

    return total_loss

def validate_batch(data, model, criterion):
    """
    Function to validate one batch of Data

    Args:
        data: Data for model validation
        model: Deep Learning Model
        criterion: Loss Functions
    Returns:
        total_loss: Total loss value for the batch
        gender_acc: Number of correct gender predictions in the batch
        age_mae: Sum of absolute age errors in the batch (scaled units)
    """

    model.eval()
    img, age, gender = data

    # Obtain predictions and losses without tracking gradients
    with torch.no_grad():
        pred_gender, pred_age = model(img)

        # Drop the trailing dimension so predictions match the [batch] targets
        pred_gender = pred_gender.squeeze(-1)
        pred_age = pred_age.squeeze(-1)

        # Calculate Loss
        gender_criterion, age_criterion = criterion
        gender_loss = gender_criterion(pred_gender, gender)
        age_loss = age_criterion(pred_age, age)

        total_loss = gender_loss + age_loss

    # Count correct gender predictions and sum the absolute age errors.
    # Both tensors have shape [batch], so nothing is broadcast here.
    pred_gender = (pred_gender > 0.5).float()
    gender_acc = (pred_gender == gender).float().sum().item()
    age_mae = torch.abs(age - pred_age).float().sum().item()

    return total_loss, gender_acc, age_mae
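One detail worth checking when computing per-sample errors: the model heads output tensors of shape [batch, 1], while the targets have shape [batch]. PyTorch broadcasts such a pair silently into a [batch, batch] matrix, which inflates any sum taken over it. The sketch below shows the difference on made-up values.

```python
import torch

age = torch.tensor([0.2, 0.5, 0.9])              # targets, shape [3]
pred_age = torch.tensor([[0.1], [0.5], [0.7]])   # model output, shape [3, 1]

# Shapes [3] and [3, 1] broadcast to [3, 3]: every target is compared
# with every prediction, which is not what we want.
wrong = torch.abs(age - pred_age)

# Dropping the trailing dimension gives an element-wise comparison.
right = torch.abs(age - pred_age.squeeze(-1))

print(wrong.shape, right.shape)  # torch.Size([3, 3]) torch.Size([3])
print(right.sum().item())        # about 0.3 (0.1 + 0.0 + 0.2)
```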



Training

We have everything prepared now and we can proceed to train and validate the model. The following steps are followed for model training -
  • Initialize wandb to log metrics
  • Load the model, loss function and optimizer
  • Create empty lists to store the metrics
  • Run the training loop for given epochs for the training and validation dataloaders
  • Store and log the losses and accuracies per epoch
  • Save the model
  • Finish the wandb instance
def run():
    """
    Function to run the model training, log the metrics and save the model
    """

    # Initialize wandb
    wandb.init(project = "age-gender-classification")

    # Load model, loss functions, optimizer
    model, criterion, optimizer = get_model()

    # Create empty lists to store metrics
    val_gender_accuracies = []
    val_age_maes = []
    train_losses = []
    val_losses = []

    best_test_loss = 1000
    start = time.time()

    # Training Loop
    for epoch in range(config.NUM_EPOCHS):

        print(f"########## Epoch: {epoch} ##########")

        epoch_train_loss, epoch_test_loss = 0, 0
        val_age_mae, val_gender_acc, ctr = 0, 0, 0

        for ix, data in tqdm(enumerate(train_loader), total = len(train_loader)):
            loss = train_batch(data, model, criterion, optimizer)
            epoch_train_loss += loss.item()

        for ix, data in tqdm(enumerate(val_loader), total = len(val_loader)):
            loss, gender_acc, age_mae = validate_batch(data, model, criterion)
            epoch_test_loss += loss.item()
            val_age_mae += age_mae
            val_gender_acc += gender_acc
            ctr += len(data[0])

        # Per-sample averages for the metrics, per-batch averages for the losses
        val_age_mae /= ctr
        val_gender_acc /= ctr
        epoch_train_loss /= len(train_loader)
        epoch_test_loss /= len(val_loader)

        wandb.log(
            {
                'Train Loss': epoch_train_loss,
                'Val Loss': epoch_test_loss,
                'Val Age MAE': val_age_mae,
                'Val Gender Accuracy': val_gender_acc
            }
        )

        elapsed = time.time()-start
        best_test_loss = min(best_test_loss, epoch_test_loss)
        # Remaining time = epochs still to run x average time per finished epoch
        print('{}/{} ({:.2f}s - {:.2f}s remaining)'.format(epoch+1, config.NUM_EPOCHS, elapsed, (config.NUM_EPOCHS-epoch-1)*(elapsed/(epoch+1))))

        info = f'''Epoch: {epoch+1:03d}\tTrain Loss: {epoch_train_loss:.3f}\tTest: {epoch_test_loss:.3f}\tBest Test Loss: {best_test_loss:.4f}'''
        info += f'\nGender Accuracy: {val_gender_acc*100:.2f}%\tAge MAE: {val_age_mae:.2f}\n'
        print(info)

        train_losses.append(epoch_train_loss)
        val_losses.append(epoch_test_loss)
        val_gender_accuracies.append(val_gender_acc)
        val_age_maes.append(val_age_mae)

    # Make sure the target directory exists before saving the weights
    os.makedirs(os.path.dirname(config.MODEL_PATH), exist_ok = True)
    torch.save(model.state_dict(), config.MODEL_PATH)
    wandb.finish()
To run, simply execute the above function -
run()
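The training loop divides the summed gender hits and age errors by the number of samples seen (ctr), which weights every sample equally. Averaging per-batch means instead would give a different number whenever the last batch is smaller. A minimal sketch with made-up counts:

```python
# (number_correct, batch_size) for three hypothetical validation batches;
# the last batch is smaller because the dataset size is not a multiple of 4.
batches = [(4, 4), (2, 4), (0, 1)]

# Per-sample accuracy: total correct / total samples seen
correct = sum(c for c, _ in batches)
seen = sum(n for _, n in batches)
per_sample_accuracy = correct / seen                                   # 6 / 9

# Mean of batch means: the single sample in the last batch counts
# as much as the four samples in each full batch.
mean_of_batch_means = sum(c / n for c, n in batches) / len(batches)    # 1.5 / 3

print(round(per_sample_accuracy, 3), round(mean_of_batch_means, 3))  # 0.667 0.5
```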





Test Predictions

Now that we have successfully built our model, we can use it for predictions. I downloaded a couple of images from the internet that are not present in the training or validation data and ran a simple inference on them. The predictions are also logged to wandb so that you don't lose them!

def plot_results(img_path):
    """
    Function to run inference on test data and save the results
    """

    res_path = "/content/result.png"
    # Load the image and convert BGR -> RGB, matching the training pipeline
    im = cv2.imread(img_path)
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    # Preprocess the image
    train = GenderAgeDataset(train_df)
    im_tensor = train.preprocess_image(im).to(config.DEVICE)

    # Load Model
    model, _, _ = get_model()
    model.load_state_dict(torch.load(config.MODEL_PATH))
    model.eval()  # Disable dropout for inference

    # Obtain Predictions
    with torch.no_grad():
        gender, age = model(im_tensor)
    pred_gender = gender.to('cpu').detach().numpy()
    pred_gender = np.where(pred_gender[0][0]<0.5,'Male','Female')

    pred_age = age.to('cpu').detach().numpy()
    pred_age = int(pred_age[0][0]*79)  # Undo the age scaling

    # Plot Image with Predictions
    plt.title(f"Gender: {pred_gender} Age: {pred_age}")
    plt.axis('off')

    plt.imshow(im)

    # Save Image
    plt.savefig(res_path)

    # Initialize a new run
    wandb.init(project = "age-gender-classification")

    res = plt.imread(res_path)

    # Log the image
    wandb.log({"Result": [wandb.Image(res)]})

    wandb.finish()

plot_results("/content/test-1.jpg")
plot_results("/content/test-2.jpg")
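The decoding step of the inference function can be isolated into a small helper. The function name decode is hypothetical and the input values are made up. The threshold of 0.5 and the factor of 79 are the ones used in the code above.

```python
MAX_AGE = 79  # same scaling constant used for the age labels

def decode(gender_prob, scaled_age, threshold=0.5):
    """Turn the two sigmoid outputs into a gender label and an age in years."""
    gender = "Male" if gender_prob < threshold else "Female"
    age = int(scaled_age * MAX_AGE)   # undo the scaling, truncate to whole years
    return gender, age

print(decode(0.83, 0.40))  # ('Female', 31)
print(decode(0.12, 0.25))  # ('Male', 19)
```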



We can see that the model works decently and is able to classify the gender and estimate the age of the given test images.



Conclusion

Thus, we have successfully implemented Multi-Task Learning in PyTorch: using a single model in a single pipeline, we performed two different tasks, classification and regression. You can try out more such models and find different use cases for Multi-Task Learning!
The entire code is available on GitHub in the repository age-and-gender-classifier or in Colab.
If you still face any difficulties, reach out to me on LinkedIn or Twitter. My messages are open :)