Skip to main content

PyTorchでのLSTMの使用:チュートリアル コーディング例あり

コードとインタラクティブな視覚化機能を備えた、PyTorchでLSTMを使用する方法を説明するチュートリアルです。
Created on August 15|Last edited on August 17
このレポートは、Saurav Maheshkarによる「Using LSTM in PyTorch: A Tutorial With Examples」の翻訳です。



紹介する内容




PyTorchでのLSTMの使用

このレポートでは、PyTorchでのLong Short-Term Memory( LSTM )の使用方法について簡単な例を見ていきます。関連するコードと手順を以下で確認することもできます。
LSTMに先立ち、NLPフィールドはほとんどの場合、言語モデリングのnnn-gramsのようなコンセプトを使用しています。ここでのnnnは、一連の単語/文字数を表しています。 たとえば、「Hi my friend」はトリグラムの単語です。しかし、このような統計モデルは、単語間の長期的な相互作用を取り込む場合に失敗してしまいます。 4または5以上のものは、当時の計算ではほとんど不可能でした。
そこで、この問題を解決したリカレントニューラルネットワーク(RNN)が登場しました。この種のモデルアーキテクチャは、基本的にループの周りに基づいています。これにより、フォワードパス中に新しい情報を組み込みながら、前の情報の一部を戻して保持することができます。長・短期メモリユニット(LSTM)は、効果的な「ゲーティング」メカニズムを導入することにより、RNNおよびゲート付き回帰型ユニット(GRU)上で改善されたRNNの特殊なタイプです。
画像: Christopher Olah 's Blog
LSTMの仕組みの理論上の知識については、このビデオをご覧ください。
すでにLSTMについてご存知の場合は、こちらまで移動することができます。

Let's get going.


では、始めましょう。







LSTMをPyTorchモデルへ追加

PyTorchのnnモジュールを使用すると、torch.nn.LSTMクラスを使用して、LSTMをモデルに簡単にレイヤーとして追加できます。
注意すべき2つの重要なパラメータは次のとおりです。-
  • input_size:インプットで予想される機能の数
  • hidden_size:非表示状態(hh)にある機能の数

サンプルモデルコード

import torch.nn as nn
from torch.autograd import Variable

class MyModel(nn.Module):
def __init__(self, ...):
...
self.lstm = nn.LSTM(embedding_length, hidden_size)
self.label = nn.Linear(hidden_size, output_size)

def forward(self):

h_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())
c_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())

output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0))

return self.label(final_hidden_state[-1])



モデルのトレーニング

Weights & Biasesを使用すると、wandb.log ()を使用してメトリックを簡単に記録できます。詳細については、ドキュメントを参照してください
wandb.watch(model)

def train_model(model, train_iter, epoch):
...
model.train()
for idx, batch in enumerate(train_iter):
...
prediction = model(text)
loss = loss_fn(prediction, target)
wandb.log({"Training Loss": loss.item()})
num_corrects = (torch.max(prediction, 1)[1].view(target.size()).data == target.data).float().sum()
acc = 100.0 * num_corrects/len(batch)
wandb.log({"Training Accuracy": acc.item()})
...




PyTorchを使用したLSTM実装からの観察結果

上記のグラフは、IMDBデータセットでトレーニングされたテキスト分類モデルのトレーニングおよび評価損失・精度を示しています。モデルは、事前に訓練されたGLoVE埋め込みを使用し、高出力密度ヘッドを備えた単一の一方向性LSTMレイヤーを搭載していました。モデルはわずか10エポックでトレーニングされていたにもかかわらず、トレーニング精度は~90%に達しました。

まとめ

PyTorchでのLSTMの使用に関する簡単なチュートリアルは以上となります。
W&B機能の完全版を確認するには、こちらの5分間のショートガイドをご覧ください。
LSTMの詳細を記載したFully Connectedでのその他のレポートもチェックしてみてください。

おすすめ記事


Iterate on AI agents and models faster. Try Weights & Biases today.