Multi-Task Learning in PyTorch
Implement an Age and Gender Classifier using a single model
Many deep learning projects focus on accomplishing a single task with a single model - for example, an animal classifier has the one task of classifying all animals correctly.
In many instances, however, we have several similar tasks that need to be predicted together.
Take the example of an autonomous car. Given the live input feed, the car needs to perform several tasks such as lane detection, signboard detection, obstacle detection and so on. You have two choices here: either create separate training pipelines for each task, or combine them into a single pipeline, since all these tasks are similar and share the same input (i.e., the live feed from the car's sensors).
This is where Multi-Task Learning comes into the picture.
Multi-Task Learning is a branch of research where a single input (or a few inputs) is used to predict several different but ultimately connected outputs.
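Conceptually, this usually means a shared backbone feeding one output head per task, with the losses combined into a single objective. Here is a minimal, illustrative sketch (the layer sizes and names are made up for illustration; the model we build later in this article follows the same pattern):

import torch
import torch.nn as nn

# A toy multi-task model: one shared trunk, two task-specific heads
class TinyMultiTaskNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(16, 8), nn.ReLU())  # shared representation
        self.head_a = nn.Linear(8, 1)   # e.g. a regression output
        self.head_b = nn.Linear(8, 1)   # e.g. a binary classification output

    def forward(self, x):
        x = self.shared(x)
        return self.head_a(x), self.head_b(x)

model = TinyMultiTaskNet()
out_a, out_b = model(torch.randn(4, 16))  # one input, two related outputs
print(out_a.shape, out_b.shape)           # torch.Size([4, 1]) torch.Size([4, 1])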
In this article, I will demonstrate how you can perform age and gender classification on the same dataset in a single training pipeline.
We will use a dataset of human face images labelled with attributes such as age and gender.
Let's start directly with the code!
Contents
- Age and Gender Classification using Deep Learning
- Preliminaries
- Global Configuration
- Download Dataset
- Load Dataset
- Dataset Class
- Create Datasets and DataLoaders
- Model
- Engine
- Training
- Test Predictions
- Conclusion
Age and Gender Classification using Deep Learning
We begin by importing the libraries that will be required.
Preliminaries
# Necessities
import os
import cv2
import glob
import time
import random
import pandas as pd
import numpy as np

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms

# Data Visualization
import matplotlib.pyplot as plt
%matplotlib inline

# Progress Bar
from tqdm.notebook import tqdm

# File download
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# wandb
!pip install wandb
import wandb
wandb.login()
Global Configuration
Next, we define the global configuration, which holds the file paths and hyperparameters.
class config:
    DIRECTORY_PATH = '/content'
    TRAIN_FILE_PATH = os.path.join(DIRECTORY_PATH, 'fairface-label-train.csv')
    VAL_FILE_PATH = os.path.join(DIRECTORY_PATH, 'fairface-label-val.csv')
    # Note: no leading slash, so the model path stays inside DIRECTORY_PATH
    MODEL_PATH = os.path.join(DIRECTORY_PATH, 'models', 'model.bin')
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    SEED = 42
    NUM_EPOCHS = 5
    IMAGE_SIZE = 224
    BATCH_SIZE = 32
    LEARNING_RATE = 1e-4
    PRETRAINED = True
To be able to reproduce the results, it is very important to seed everything, which is done using the function below.
def seed_everything(SEED=config.SEED):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    os.environ['PYTHONHASHSEED'] = str(SEED)
Download Dataset
Then, we download the dataset from Google Drive using the authentication method shown below.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

def getFile_from_drive(file_id, name):
    downloaded = drive.CreateFile({'id': file_id})
    downloaded.GetContentFile(name)

getFile_from_drive('1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86', 'fairface-img-margin025-trainval.zip')
getFile_from_drive('1k5vvyREmHDW5TSM9QgB04Bvc8C8_7dl-', 'fairface-label-train.csv')
getFile_from_drive('1_rtz1M1zhvS0d5vVoXUamnohB6cJ02iJ', 'fairface-label-val.csv')

!unzip -qq fairface-img-margin025-trainval.zip
Load Dataset
The dataset has been downloaded and now we can create the training and validation dataframes.
train_df = pd.read_csv(config.TRAIN_FILE_PATH)
val_df = pd.read_csv(config.VAL_FILE_PATH)
train_df.head()
From the above dataset, we will use only the file, age and gender columns for training and validation.
Note: Whenever working with continuous features, scaling becomes important. Scaling is simply achieved by dividing a feature by its maximum value. Thus, in this case, to scale the images we divide them by 255 (the maximum pixel value), and to scale the age feature we divide it by 79 (the maximum age value).
# Maximum Value of Age
max(train_df.age)  # prints 79
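For intuition, here is a minimal sketch of that scaling on a dummy image and a single age value (in the real pipeline, this happens inside the dataset class below):

import numpy as np

MAX_PIXEL_VALUE = 255.0
MAX_AGE = 79.0  # maximum age in the training labels

img = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)  # dummy image
age = 40

scaled_img = img / MAX_PIXEL_VALUE   # pixel values now lie in [0, 1]
scaled_age = age / MAX_AGE           # 40 / 79 ≈ 0.506
print(scaled_img.min(), scaled_img.max(), round(scaled_age, 3))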
Dataset Class
We now move on to the PyTorch Dataset class. A Dataset class typically implements three important methods -
- __init__() - To initialize the Dataset module
- __len__() - To provide the number of samples
- __getitem__() - To obtain one sample at a time
In this case, however, we also preprocess the images and collate (i.e., combine) the samples into batches inside the dataset class itself. To achieve this, we add two more methods -
- preprocess_image() - To resize, convert to tensors, permute, normalize and scale the images
- collate_fn() - To combine the preprocessed images into a batch, scale the targets and return tensors
In the preprocess_image() function, we perform the following preprocessing -
- Resize all images to the size 224 x 224
- Convert images to PyTorch Tensors
- Permute the image dimensions into the channels x height x width layout that PyTorch expects
- Scale the images by dividing them by the highest pixel value of 255
- Normalize the image with the ImageNet mean and standard deviation, since we are using a pretrained model
class GenderAgeDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        data = self.df.iloc[index].squeeze()
        file = data.file
        age = data.age
        gender = data.gender == "Female"
        img = cv2.imread(file)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img, age, gender

    def preprocess_image(self, img):
        img = cv2.resize(img, (config.IMAGE_SIZE, config.IMAGE_SIZE))
        img = torch.tensor(img).permute(2, 0, 1)
        img = self.normalize(img / 255.)
        return img[None]

    def collate_fn(self, batch):
        imgs, ages, genders = [], [], []
        for img, age, gender in batch:
            img = self.preprocess_image(img)
            imgs.append(img)
            age = float(int(age) / 79)  # Scaling the age with its maximum value
            ages.append(age)
            gender = float(gender)
            genders.append(gender)
        ages, genders = [torch.tensor(x).to(config.DEVICE).float() for x in [ages, genders]]
        imgs = torch.cat(imgs).to(config.DEVICE)
        return imgs, ages, genders
Create Datasets and DataLoaders
We define a function get_data() that creates the datasets and returns the training and validation dataloaders.
def get_data():
    """Function to obtain the Training and Validation DataLoaders

    Returns:
        train_loader: Training DataLoader
        val_loader: Validation DataLoader
    """
    train = GenderAgeDataset(train_df)
    val = GenderAgeDataset(val_df)
    train_loader = DataLoader(train,
                              batch_size=config.BATCH_SIZE,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train.collate_fn)
    val_loader = DataLoader(val,
                            batch_size=config.BATCH_SIZE,
                            shuffle=False,
                            collate_fn=val.collate_fn)
    return train_loader, val_loader
train_loader, val_loader = get_data()

# Test the DataLoader
img, age, gender = next(iter(train_loader))
print(img.shape, age.shape, gender.shape)
# prints torch.Size([32, 3, 224, 224]) torch.Size([32]) torch.Size([32])
In the above output:
- torch.Size([32, 3, 224, 224]) is the image batch in the form batch_size x channels x height x width
- The two torch.Size([32]) tensors hold the age and gender targets respectively, one value per image in the batch
Model
To implement this task, we will use the pretrained version of the VGG16 model.
model = models.vgg16(pretrained=True)
model

VGG16 Model Architecture
Note that the VGG16 architecture has three key sub-modules - features, avgpool and classifier. For our work, we will freeze the features sub-module and override the avgpool and classifier sub-modules with custom layers.
Freezing means that the parameters (weights and biases) of the frozen layers will not be updated during training, because we want to preserve the learning the model gained during pre-training. This is achieved by setting param.requires_grad = False, as sketched below.
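As a quick, illustrative sketch of what freezing looks like (the get_model() function below applies the same idea to our actual model; the parameter count check is only for illustration):

from torchvision import models

backbone = models.vgg16(pretrained=True)

# Freeze the pretrained convolutional backbone so its weights are not updated
for param in backbone.features.parameters():
    param.requires_grad = False

# Sanity check: how many parameters are trainable vs. frozen?
trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in backbone.parameters() if not p.requires_grad)
print(f"Trainable: {trainable:,} | Frozen: {frozen:,}")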
To build the model, we consider the following points -
- All the layers under the features sub-module remain the same.
- We override the avgpool sub-module with a custom convolutional layer.
- We override the classifier sub-module by creating custom layers in the GenderAgeClassifier class.
- In the classifier, we create two separate layers (age_classifier and gender_classifier) branching out from the intermediate layer.
- Note that we use two different loss functions because the tasks differ: for age_classifier we use L1Loss(), since age is a continuous value, and for gender_classifier we use BCELoss(), since gender is a binary value.
- We combine the age estimation loss and the gender classification loss into a single total loss - here a plain sum, i.e., an equally weighted summation (see the sketch after this list).
- Minimize the overall loss via backpropagation, which updates the trainable weights.
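The training function below simply adds the two losses with equal weight. If one task starts to dominate, you can weight the terms; the snippet below is only a hedged sketch with hypothetical weights w_gender and w_age (they are not part of the original code):

import torch
import torch.nn as nn

w_gender, w_age = 1.0, 2.0  # hypothetical task weights; tune for your own tasks

gender_criterion, age_criterion = nn.BCELoss(), nn.L1Loss()

# Dummy predictions and targets, just to make the snippet runnable
pred_gender, gender = torch.rand(32), torch.randint(0, 2, (32,)).float()
pred_age, age = torch.rand(32), torch.rand(32)

total_loss = (w_gender * gender_criterion(pred_gender, gender)
              + w_age * age_criterion(pred_age, age))
print(total_loss.item())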
def get_model():
    """Function to create the model with custom classifier, define loss functions and optimizer

    Returns:
        model: Pretrained Model with custom classifier
        loss_functions: The two loss functions for the tasks
        optimizer: Optimizer for the Neural Network
    """
    # Load Model
    model = models.vgg16(pretrained=config.PRETRAINED)

    # Freeze Model
    for param in model.parameters():
        param.requires_grad = False

    # Overwrite avgpool layer
    model.avgpool = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3),
                                  nn.MaxPool2d(2),
                                  nn.ReLU(),
                                  nn.Flatten())

    # Model Custom Classifier
    class GenderAgeClassifier(nn.Module):
        def __init__(self):
            super(GenderAgeClassifier, self).__init__()
            self.intermediate = nn.Sequential(nn.Linear(2048, 512),
                                              nn.ReLU(),
                                              nn.Dropout(0.4),
                                              nn.Linear(512, 128),
                                              nn.ReLU(),
                                              nn.Dropout(0.4),
                                              nn.Linear(128, 64),
                                              nn.ReLU())
            self.age_classifier = nn.Sequential(nn.Linear(64, 1),
                                                nn.Sigmoid())
            self.gender_classifier = nn.Sequential(nn.Linear(64, 1),
                                                   nn.Sigmoid())

        def forward(self, x):
            x = self.intermediate(x)
            age = self.age_classifier(x)
            gender = self.gender_classifier(x)
            return gender, age

    # Override model classifier
    model.classifier = GenderAgeClassifier()

    # Loss Functions
    gender_criterion = nn.BCELoss()
    age_criterion = nn.L1Loss()
    loss_functions = gender_criterion, age_criterion

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    return model.to(config.DEVICE), loss_functions, optimizer
Engine
In the engine, we define the functions that train and evaluate one batch of data at a time.
# Training Function
def train_batch(data, model, criterion, optimizer):
    """Function to train one batch of Training Data

    Args:
        data: Data for model training
        model: Deep Learning Model
        criterion: Loss Functions
        optimizer: Optimizer

    Returns:
        total_loss: Total loss value for the batch
    """
    model.train()
    img, age, gender = data             # Get Data
    optimizer.zero_grad()               # Zero the gradients
    pred_gender, pred_age = model(img)  # Obtain predictions

    # Calculate Loss
    gender_criterion, age_criterion = criterion
    gender_loss = gender_criterion(pred_gender.squeeze(), gender)
    age_loss = age_criterion(pred_age.squeeze(), age)
    total_loss = gender_loss + age_loss
    total_loss.backward()
    optimizer.step()                    # Update Parameters
    return total_loss


def validate_batch(data, model, criterion):
    """Function to validate one batch of Data

    Args:
        data: Data for model validation
        model: Deep Learning Model
        criterion: Loss Functions

    Returns:
        total_loss: Total loss value for the batch
        gender_acc: Accuracy for gender predictions
        age_mae: Mean Absolute Error for age predictions
    """
    model.eval()
    img, age, gender = data

    # Obtain predictions
    with torch.no_grad():
        pred_gender, pred_age = model(img)

    # Calculate Loss
    gender_criterion, age_criterion = criterion
    gender_loss = gender_criterion(pred_gender.squeeze(), gender)
    age_loss = age_criterion(pred_age.squeeze(), age)
    total_loss = gender_loss + age_loss

    # Calculate Accuracy and MAE
    pred_gender = (pred_gender > 0.5).squeeze()
    gender_acc = (pred_gender == gender).float().sum()
    # Squeeze the predictions so the shapes match ([batch] vs [batch, 1])
    age_mae = torch.abs(age - pred_age.squeeze()).float().sum()
    return total_loss, gender_acc, age_mae
Training
Everything is now in place, so we can proceed to train and validate the model. Model training follows these steps -
- Initialize wandb to log metrics
- Load the model, loss function and optimizer
- Create empty lists to store the metrics
- Run the training loop for given epochs for the training and validation dataloaders
- Store and log the losses and accuracies per epoch
- Save the model
- Finish the wandb instance
def run():
    """Function to run the model training"""
    # Initialize wandb
    run = wandb.init(project="age-gender-classification")

    # Load model, loss functions, optimizer
    model, criterion, optimizer = get_model()

    # Create empty lists to store metrics
    val_gender_accuracies = []
    val_age_maes = []
    train_losses = []
    val_losses = []

    best_test_loss = 1000
    start = time.time()

    # Training Loop
    for epoch in range(config.NUM_EPOCHS):
        print(f"########## Epoch: {epoch} ##########")
        epoch_train_loss, epoch_test_loss = 0, 0
        val_age_mae, val_gender_acc, ctr = 0, 0, 0

        for ix, data in tqdm(enumerate(train_loader), total=len(train_loader)):
            loss = train_batch(data, model, criterion, optimizer)
            epoch_train_loss += loss.item()

        for ix, data in tqdm(enumerate(val_loader), total=len(val_loader)):
            loss, gender_acc, age_mae = validate_batch(data, model, criterion)
            epoch_test_loss += loss.item()
            val_age_mae += age_mae
            val_gender_acc += gender_acc
            ctr += len(data[0])

        val_age_mae /= ctr
        val_gender_acc /= ctr
        epoch_train_loss /= len(train_loader)
        epoch_test_loss /= len(val_loader)

        # Log the epoch metrics to wandb
        wandb.log({'Train Loss': epoch_train_loss,
                   'Val Loss': epoch_test_loss,
                   'Val Age MAE': val_age_mae,
                   'Val Gender Accuracy': val_gender_acc})

        elapsed = time.time() - start
        best_test_loss = min(best_test_loss, epoch_test_loss)
        print('{}/{} ({:.2f}s - {:.2f}s remaining)'.format(
            epoch + 1, config.NUM_EPOCHS, time.time() - start,
            (config.NUM_EPOCHS - epoch) * (elapsed / (epoch + 1))))
        info = f'''Epoch: {epoch+1:03d}\tTrain Loss: {epoch_train_loss:.3f}\tTest: {epoch_test_loss:.3f}\tBest Test Loss: {best_test_loss:.4f}'''
        info += f'\nGender Accuracy: {val_gender_acc*100:.2f}%\tAge MAE: {val_age_mae:.2f}\n'
        print(info)

        val_gender_accuracies.append(val_gender_acc)
        val_age_maes.append(val_age_mae)

    # Save the model (create the target directory first if it does not exist)
    os.makedirs(os.path.dirname(config.MODEL_PATH), exist_ok=True)
    torch.save(model.state_dict(), config.MODEL_PATH)
    wandb.finish()
To run, simply execute the above function -
run()
Test Predictions
Now that we have successfully built our model, we can use it for predictions. I downloaded a couple of images from the internet that are not present in the training or validation data and ran a simple inference on them. The predictions are also logged to wandb so that you don't lose them!
def plot_results(img_path):
    """Function to run inference on test data and save the results"""
    res_path = "/content/result.png"

    # Load and preprocess image (convert BGR to RGB to match training)
    im = cv2.imread(img_path)
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    train = GenderAgeDataset(train_df)
    im = train.preprocess_image(im).to(config.DEVICE)

    # Load Model
    model, _, _ = get_model()
    model.load_state_dict(torch.load(config.MODEL_PATH))
    model.eval()  # disable dropout for inference

    # Obtain Predictions
    gender, age = model(im)
    pred_gender = gender.to('cpu').detach().numpy()
    pred_gender = np.where(pred_gender[0][0] < 0.5, 'Male', 'Female')
    pred_age = age.to('cpu').detach().numpy()
    pred_age = int(pred_age[0][0] * 79)

    # Plot Image with Predictions
    im = cv2.imread(img_path)
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    plt.title(f"Gender: {pred_gender} Age: {pred_age}")
    plt.axis('off')
    plt.imshow(im)

    # Save Image
    plt.savefig(res_path)

    # Initialize a new run
    wandb.init(project="age-gender-classification")
    res = plt.imread(res_path)

    # Log the image
    wandb.log({"Result": [wandb.Image(res)]})
    wandb.finish()
plot_results("/content/test-1.jpg")
plot_results("/content/test-2.jpg")
We can see that the model works decently and is able to classify the gender and estimate the age of the given test images.
Conclusion
Thus, we have successfully implemented Multi-Task Learning in PyTorch: with a single model in a single pipeline, we tackled two different tasks, classification and regression. You can try out more such models and find other use cases for Multi-Task Learning!