Skip to main content

在PyTorch中使用LSTM:示例教程

本教程将对如何在PyTorch中使用LSTM进行介绍,同时还将提供一些代码示例和互动式可视化内容
Created on August 15|Last edited on August 17
本报告是作者Saurav Maheshkar所写的"Using LSTM in PyTorch: A Tutorial With Examples"的翻译




我们将介绍的内容




在PyTorch中使用 LSTM

在本文中,我们将使用一则简短示例来演示如何在PyTorch中使用长短期记忆网络(LSTM)。您还将在下文中找到相关代码与说明。
在LSTM出现之前,NLP领域中主要会使用诸如 nnn-grams等概念来对语言进行建模。其中,nnn表示放入系列中的单词/字符数量。例如,“Hi my friend”是符合三词文法(tri-gram)的句子。但是,此类统计模型在捕获单词之间的长期交互时会失效。任何大于 4 或 5 词的句子,一度都几乎是无法计算的。
然后循环神经网络(RNN)的出现解决了这一问题。本质上而言,这类模型的体系结构是基于循环的,允许不断循环并保留部分先前信息,同时在前向传递的过程中合并新信息。LSTM 是一种特殊类型的 RNN,也是基于RNN和门控循环单元(GRU)进行的进一步改进,引入了高效的“门控”机制。
图像来源:Christopher Olah 的博客
如想了解LSTM的理论工作原理,可观看本视频。
如果您已经对LSTM有所了解,则可跳转至此处。

我们现在就开始吧。

如果您想跟着下方的示例逐步实操,可以看看我专门为此创建的实用Colab。




将LSTM添加至您的PyTorch模型中

PyTorch 的 nn 模型可帮助我们轻松地在模型中将 LSTM 添加为一层,只需使用 torch.nn.LSTM 类即可。
您应该关心的两个重要参数是:
  • input_size:输入中的预计特征数量
  • hidden_size:隐藏状态 hh 中的特征数量

示例模型代码

import torch.nn as nn
from torch.autograd import Variable

class MyModel(nn.Module):
def __init__(self, ...):
...
self.lstm = nn.LSTM(embedding_length, hidden_size)
self.label = nn.Linear(hidden_size, output_size)

def forward(self):

h_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())
c_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())

output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0))

return self.label(final_hidden_state[-1])



训练模型

借助Weights & Biases,您可轻松使用 wandb.log() 来记录指标。查阅文档以了解完整细节
wandb.watch(model)

def train_model(model, train_iter, epoch):
...
model.train()
for idx, batch in enumerate(train_iter):
...
prediction = model(text)
loss = loss_fn(prediction, target)
wandb.log({"Training Loss": loss.item()})
num_corrects = (torch.max(prediction, 1)[1].view(target.size()).data == target.data).float().sum()
acc = 100.0 * num_corrects/len(batch)
wandb.log({"Training Accuracy": acc.item()})
...




在使用PyTorch进行LSTM实现过程中的发现

上图展示了在IMDB数据集上训练文本分类模型时的训练和评估损失及准确性。模型使用了预训练的 GLoVE 嵌入向量,并有一个带有 Dense 输出 Head 的无方向 LSTM 层。
尽管模型仅完成了 10 个 epoch 的训练,但其训练精度仍然达到了 90% 左右 。

结论

好了,我们关于在PyTorch使用LSTM的简短指南就到这里了。
如需了解W&B的完整功能,请查看这一 5分钟简短指南
此外还可在Fully Connected上查看有关LSTM的其他详尽报告。

推荐阅读


Iterate on AI agents and models faster. Try Weights & Biases today.