Using LSTM in PyTorch: A Tutorial With Examples
This article provides a tutorial on how to use Long Short-Term Memory (LSTM) in PyTorch, complete with code examples and interactive visualizations using W&B.
Created on September 9 | Last edited on June 26
In this article, we'll walk through a quick example showcasing how you can get started with Long Short-Term Memory networks (LSTMs) in PyTorch. You'll also find the relevant code and instructions below. Here's what we'll be covering:
Table of Contents
Using LSTM In PyTorch
Adding LSTM To Your PyTorch Model
Sample Model Code
Training Your Model
Observations from our LSTM Implementation Using PyTorch
Conclusion
Recommended Reading
Using LSTM In PyTorch
Prior to LSTMs, the NLP field mostly relied on concepts like n-grams for language modeling, where n denotes the number of words/characters taken in series. For instance, "Hi my friend" is a word tri-gram. But this kind of statistical model fails to capture long-term interactions between words, and anything with n higher than 4 or 5 was pretty much impossible on the compute available at the time.
Eventually, Recurrent Neural Networks (RNNs) came along to address this problem. This kind of architecture is essentially built around loops, which allow the network to carry part of the prior information forward while incorporating new information during the forward pass. Long Short-Term Memory units (LSTMs) are a special type of RNN that improve on vanilla RNNs by introducing an effective "gating" mechanism; Gated Recurrent Units (GRUs) are a later, simplified variant of the same gated idea.

Image Credits: Christopher Olah's Blog
For a theoretical understanding of how LSTMs work, check out this video.
If you're already familiar with LSTMs, you can jump straight to the code here.
Let's get going.
If you'd like to follow along with the example below, here's a handy Colab I created to allow you to do just that.
Adding LSTM To Your PyTorch Model
PyTorch's nn module allows us to easily add an LSTM layer to our models using the torch.nn.LSTM class.
The two important parameters you should care about (illustrated in the short example below) are:
- input_size: number of expected features in the input
- hidden_size: number of features in the hidden state
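To make these concrete, here's a minimal sketch; the layer sizes, sequence length, and batch size below are made up purely for illustration. By default, nn.LSTM expects input of shape (seq_len, batch, input_size) and returns an output of shape (seq_len, batch, hidden_size):
import torch
import torch.nn as nn

# 32 input features per timestep, 64 features in the hidden state
lstm = nn.LSTM(input_size=32, hidden_size=64)

# Default layout is (seq_len, batch, input_size): 10 timesteps, batch of 4
x = torch.randn(10, 4, 32)
output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([10, 4, 64]) -- hidden state at every timestep
print(h_n.shape)     # torch.Size([1, 4, 64])  -- final hidden state per layer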
Sample Model Code
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, ...):
        ...
        self.lstm = nn.LSTM(embedding_length, hidden_size)
        self.label = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        # Initial hidden and cell states: (num_layers, batch, hidden_size)
        batch_size = input.size(1)
        h_0 = torch.zeros(1, batch_size, self.hidden_size, device=input.device)
        c_0 = torch.zeros(1, batch_size, self.hidden_size, device=input.device)
        output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0))
        return self.label(final_hidden_state[-1])
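A note on final_hidden_state[-1] above: nn.LSTM returns the final hidden state with shape (num_layers * num_directions, batch, hidden_size), so indexing with -1 selects the last layer's final hidden state before it goes into the linear head. For the single-layer, unidirectional LSTM used here, that's the same as the hidden state at the final time step, but the indexing starts to matter once you stack layers. A quick check, with illustrative sizes:
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=100, hidden_size=256, num_layers=2)
x = torch.randn(50, 8, 100)  # (seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x)

print(h_n.shape)      # torch.Size([2, 8, 256]) -- one final hidden state per layer
print(h_n[-1].shape)  # torch.Size([8, 256])    -- last layer only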
Training Your Model
With Weights & Biases, you can easily log your metrics using wandb.log(). See the docs for full details.
wandb.watch(model)

def train_model(model, train_iter, epoch):
    ...
    model.train()
    for idx, batch in enumerate(train_iter):
        ...
        prediction = model(text)
        loss = loss_fn(prediction, target)
        wandb.log({"Training Loss": loss.item()})
        # Count correct predictions in the batch to compute accuracy
        num_corrects = (torch.max(prediction, 1)[1].view(target.size()) == target).float().sum()
        acc = 100.0 * num_corrects / len(batch)
        wandb.log({"Training Accuracy": acc.item()})
        ...
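Note that wandb.watch() and wandb.log() assume a run has already been started. If you're running this outside the linked Colab, a minimal setup might look like the following (the project name is just a placeholder):
import wandb

# Start a W&B run before calling wandb.watch() / wandb.log();
# the project name here is only an example.
wandb.init(project="pytorch-lstm-tutorial")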
[Interactive W&B panel: Training and Evaluation Loss and Accuracy charts — run set of 4]
Observations from our LSTM Implementation Using PyTorch
The graphs above show the training and evaluation loss and accuracy for a text classification model trained on the IMDB dataset. The model used pre-trained GloVe embeddings and a single unidirectional LSTM layer with a dense output head.
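As a side note, one common way to plug pre-trained vectors like GloVe into a PyTorch model is nn.Embedding.from_pretrained. The sketch below assumes you've already built a vocabulary-aligned weight matrix, which isn't shown in this report:
import torch
import torch.nn as nn

# glove_weights: (vocab_size, embedding_length) tensor of pre-trained vectors
# aligned with your vocabulary; random values here as a stand-in.
glove_weights = torch.randn(25000, 100)

embedding = nn.Embedding.from_pretrained(glove_weights, freeze=True)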
Even though the model was trained for only 10 epochs, it attained a decent training accuracy of ~90%.
Conclusion
And that wraps up our short tutorial on using LSTMs in PyTorch.
Check out these other reports on Fully Connected covering LSTMs in much more detail.
Recommended Reading
Under the Hood of Long Short Term Memory (LSTM)
This article explores how LSTMs work, including how to train them with NumPy, why their gradients vanish or explode, and how to visualize their connectivity.
LSTM RNN in Keras: Examples of One-to-Many, Many-to-One & Many-to-Many
In this report, I explain long short-term memory (LSTM) recurrent neural networks (RNNs) and how to build them with Keras, covering one-to-many, many-to-one, and many-to-many architectures.
How to stack multiple LSTMs in Keras?
Comments
final_hidden_state[-1]
Why have you used the -1 index on the final hidden state? Isn't the final hidden state already the output (hidden state) of the final time step?