How to Use LSTMs in PyTorch

A tutorial covering how to use LSTMs in PyTorch, complete with code and interactive visualizations. Made by Saurav Maheshkar using Weights & Biases

Sections We'll Cover

  1. Introduction
  2. Add LSTM to Your Model
  3. Training Your LSTM Model
  4. Observations from Our Implementation
  5. Conclusion

Introduction 👋🏻

In this report, we'll walk through a quick example showcasing how you can get started with long short-term memory (LSTM) networks in PyTorch. You'll find the relevant code & instructions below.
Prior to LSTMs, the NLP community mostly relied on statistical techniques like n-grams for language modelling, where n denotes the number of words/characters taken in sequence. For instance, "Hi my friend" is a tri-gram. But these statistical models fail to capture long-term interactions between words, and anything higher than a 4- or 5-gram was pretty much impossible on the compute available at the time.
Eventually, Recurrent Neural Networks (RNNs) emerged to address this problem. These architectures are essentially built around loops that let them carry part of the prior information forward while incorporating new information during the forward pass. Long Short-Term Memory units (LSTMs) are a special type of RNN that improve upon vanilla RNNs by introducing an effective "gating" mechanism, an idea later simplified in Gated Recurrent Units (GRUs).
Image Credits: Christopher Olah's Blog
For a theoretical understanding of how LSTMs work, check out this video on the Weights & Biases YouTube channel.
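To make the "loop" idea concrete, here's a minimal sketch (not from the original report; sizes are arbitrary) that unrolls a plain RNN cell over a toy sequence. An LSTM follows the same pattern, except each step also carries a separate cell state through its gates.

import torch
import torch.nn as nn

# Toy example: apply an RNN cell in a loop, feeding each step's hidden state
# back in at the next step. Sizes are arbitrary and purely illustrative.
rnn_cell = nn.RNNCell(input_size=10, hidden_size=20)

seq_len, batch_size = 5, 3
inputs = torch.randn(seq_len, batch_size, 10)  # a toy input sequence
h = torch.zeros(batch_size, 20)                # initial hidden state

for t in range(seq_len):
    h = rnn_cell(inputs[t], h)                 # prior information is carried forward

print(h.shape)  # torch.Size([3, 20]) -- a summary of the whole sequence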
Let's get going:

Quick Start for LSTM in Colab

1️⃣ Add LSTM to Your Model

PyTorch's nn module allows us to easily add an LSTM as a layer in our models using the torch.nn.LSTM class. The two important parameters you should care about are:

  1. input_size: the number of expected features in the input at each time step (the embedding length in the sample code below).
  2. hidden_size: the number of features in the LSTM's hidden state.
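As a quick sanity check before building a full model, here's a small standalone sketch (sizes chosen arbitrarily for illustration) of the shapes nn.LSTM expects and returns with the default batch_first=False layout:

import torch
import torch.nn as nn

# hypothetical sizes: 100-dimensional embeddings, 256 hidden units
lstm = nn.LSTM(input_size=100, hidden_size=256)

seq_len, batch_size = 30, 8
x = torch.randn(seq_len, batch_size, 100)  # (seq_len, batch, input_size)

output, (h_n, c_n) = lstm(x)               # h_0 and c_0 default to zeros
print(output.shape)  # torch.Size([30, 8, 256]) -- hidden state at every time step
print(h_n.shape)     # torch.Size([1, 8, 256])  -- final hidden state per layer
print(c_n.shape)     # torch.Size([1, 8, 256])  -- final cell state per layer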

🏠 Sample Model Code

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, ...):
        ...
        # embedding_length: size of each input vector; hidden_size: size of the LSTM's hidden state
        self.lstm = nn.LSTM(embedding_length, hidden_size)
        self.label = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        batch_size = input.size(1)  # input shape: (seq_len, batch, embedding_length)
        # Initial hidden and cell states; the .cuda() calls assume a GPU is available.
        # (torch.autograd.Variable is no longer needed -- plain tensors work directly.)
        h_0 = torch.zeros(1, batch_size, self.hidden_size).cuda()
        c_0 = torch.zeros(1, batch_size, self.hidden_size).cuda()
        output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0))
        # final_hidden_state[-1] is the last layer's final hidden state, used here as the sequence summary
        return self.label(final_hidden_state[-1])

2️⃣ Training Your Model

Using Weights & Biases, you can easily log your metrics with wandb.log(). See the docs for full details.
wandb.watch(model)  # let W&B automatically track gradients and model topology

def train_model(model, train_iter, epoch):
    ...
    model.train()
    for idx, batch in enumerate(train_iter):
        ...
        prediction = model(text)
        loss = loss_fn(prediction, target)
        wandb.log({"Training Loss": loss.item()})
        num_corrects = (torch.max(prediction, 1)[1].view(target.size()).data == target.data).float().sum()
        acc = 100.0 * num_corrects / len(batch)
        wandb.log({"Training Accuracy": acc.item()})
        ...
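For context, here's a hedged sketch of what a single, complete training step could look like end to end; the cross-entropy loss and the exact metric layout are illustrative assumptions rather than details taken from the original report.

import torch
import torch.nn.functional as F
import wandb

# One full training step: forward pass, loss, backward pass, parameter update,
# and metric logging to Weights & Biases. The loss function here is an assumption.
def training_step(model, optimizer, text, target):
    model.train()
    optimizer.zero_grad()
    prediction = model(text)                    # forward pass
    loss = F.cross_entropy(prediction, target)  # assumed loss function
    loss.backward()                             # backpropagate
    optimizer.step()                            # update parameters

    num_corrects = (prediction.argmax(dim=1) == target).float().sum()
    acc = 100.0 * num_corrects / target.size(0)
    wandb.log({"Training Loss": loss.item(), "Training Accuracy": acc.item()})
    return loss.item(), acc.item()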

πŸ‘ Observations from our LSTM Implementation Using PyTorch

The graphs above show the training and evaluation loss and accuracy for a text classification model trained on the IMDB dataset. The model used pretrained GloVe embeddings and had a single unidirectional LSTM layer with a dense output head. Even though the model was trained for only 10 epochs, it attained a decent training accuracy of ~90%.

Conclusion

And that wraps up our short tutorial on using LSTMs in PyTorch. To see the full suite of wandb features, please check out this short 5-minute guide.
Check out these other reports on Fully Connected covering LSTMs in much more detail.
Report Gallery