在PyTorch中使用LSTM:示例教程
本教程将对如何在PyTorch中使用LSTM进行介绍,同时还将提供一些代码示例和互动式可视化内容
Created on August 15|Last edited on August 17
Comment
我们将介绍的内容
在PyTorch中使用 LSTM
在本文中,我们将使用一则简短示例来演示如何在PyTorch中使用长短期记忆网络(LSTM)。您还将在下文中找到相关代码与说明。
在LSTM出现之前,NLP领域中主要会使用诸如 nnn-grams等概念来对语言进行建模。其中,nnn表示放入系列中的单词/字符数量。例如,“Hi my friend”是符合三词文法(tri-gram)的句子。但是,此类统计模型在捕获单词之间的长期交互时会失效。任何大于 4 或 5 词的句子,一度都几乎是无法计算的。
然后循环神经网络(RNN)的出现解决了这一问题。本质上而言,这类模型的体系结构是基于循环的,允许不断循环并保留部分先前信息,同时在前向传递的过程中合并新信息。LSTM 是一种特殊类型的 RNN,也是基于RNN和门控循环单元(GRU)进行的进一步改进,引入了高效的“门控”机制。

图像来源:Christopher Olah 的博客
如想了解LSTM的理论工作原理,可观看本视频。
如果您已经对LSTM有所了解,则可跳转至此处。
我们现在就开始吧。
如果您想跟着下方的示例逐步实操,可以看看我专门为此创建的实用Colab。
将LSTM添加至您的PyTorch模型中
PyTorch 的 nn 模型可帮助我们轻松地在模型中将 LSTM 添加为一层,只需使用 torch.nn.LSTM 类即可。
您应该关心的两个重要参数是:
- input_size:输入中的预计特征数量
- hidden_size:隐藏状态 中的特征数量
示例模型代码
import torch.nn as nnfrom torch.autograd import Variableclass MyModel(nn.Module):def __init__(self, ...):...self.lstm = nn.LSTM(embedding_length, hidden_size)self.label = nn.Linear(hidden_size, output_size)def forward(self):h_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())c_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0))return self.label(final_hidden_state[-1])
训练模型
wandb.watch(model)def train_model(model, train_iter, epoch):...model.train()for idx, batch in enumerate(train_iter):...prediction = model(text)loss = loss_fn(prediction, target)wandb.log({"Training Loss": loss.item()})num_corrects = (torch.max(prediction, 1)[1].view(target.size()).data == target.data).float().sum()acc = 100.0 * num_corrects/len(batch)wandb.log({"Training Accuracy": acc.item()})...
在使用PyTorch进行LSTM实现过程中的发现
上图展示了在IMDB数据集上训练文本分类模型时的训练和评估损失及准确性。模型使用了预训练的 GLoVE 嵌入向量,并有一个带有 Dense 输出 Head 的无方向 LSTM 层。
尽管模型仅完成了 10 个 epoch 的训练,但其训练精度仍然达到了 90% 左右 。
结论
好了,我们关于在PyTorch使用LSTM的简短指南就到这里了。
此外还可在Fully Connected上查看有关LSTM的其他详尽报告。
推荐阅读
Under the Hood of Long Short Term Memory (LSTM)
This article explores how LSTM works, including how to train them with NumPy, vanish/explode the gradient, and visualize their connectivity.
LSTM RNN in Keras: Examples of One-to-Many, Many-to-One & Many-to-Many
In this report, I explain long short-term memory (LSTM) recurrent neural networks (RNN) and how to build them with Keras. Covering One-to-Many, Many-to-One & Many-to-Many.
How to stack multiple LSTMs in keras?
Add a comment
Iterate on AI agents and models faster. Try Weights & Biases today.