Skip to main content

PyTorch でモデルを保存・読み込みする方法

本記事は、PyTorch のモデルを保存・読み込みする方法を解説する機械学習チュートリアルであり、バージョン管理には Weights & Biases を使用します。
Created on August 11|Last edited on August 11
モデルの学習には多くの計算資源と時間がかかります。実運用のユースケースでは、学習済みモデルを保存するのが多くの機械学習ワークフローの最後のステップであり、その後は推論に再利用します。
学習済みモデルを保存・読み込みする方法はいくつかあります。PyTorchこのチュートリアルでは、PyTorch で学習済みモデルを保存・読み込みするいくつかの方法を紹介します。詳しい手順については、こちらを参照してください。公式PyTorchドキュメント

目次




PyTorch モデルの保存と読み込みstate_dict(推奨)

Astate_dictは、各レイヤーをそのパラメータのテンソルに対応付ける単なる Python の辞書です。モデルの学習可能なパラメータ(畳み込み層、全結合層など)と、登録されたバッファ(BatchNorm のrunning_mean)には項目が含まれますstate_dict
使用state_dictは、次の理由から推奨される方法です。例えば次のとおりです。
  • まず、state_dictは、モデルの本質的なパラメータ(重みやバイアスなど)のみを保持するため、ファイルサイズが小さくなり、扱いやすくなります。さらに柔軟性があります。state_dict は Python の辞書であるため、モデルのパラメータだけでなく optimizer の状態やその他のメタデータも保存でき、学習の再開やファインチューニングが容易になります。加えて、state_dict は Python の組み込みモジュールである pickle と互換性があるため、異なる環境間でのモデルの保存と読み込みを簡単に行えます。本ガイドでは、PyTorch で state_dict を用いてモデルを保存・読み込みする手順と、効果的なモデル管理のためのベストプラクティスを解説します。
  • 第二に、state_dict は Python の辞書型なので、モデルのパラメータだけでなく optimizer の状態やその他のメタデータも一緒に保存でき、学習の再開やファインチューニング用のモデル
  • さらに、state_dict は Python の組み込みモジュールである pickle と互換性があるため、保存と読み込みが容易で、異なる環境間でのモデル再読み込みを簡潔にできます。本ガイドでは、PyTorch で state_dict を用いてモデルを保存・読み込みする手順と、効果的なモデル管理のためのベストプラクティスを解説します。
さらに詳しく知るには、こちらをご覧ください。state_dict ここ
次のコードは、PyTorch でモデルを保存と読み込みする方法です。state_dict

PyTorch でモデルを保存する方法state_dict

torch.save(model.state_dict(), 'save/to/path/model.pth')

PyTorch でモデルを読み込むには、 を使用します。state_dict

model = MyModelDefinition(args)
model.load_state_dict(torch.load('load/from/path/model.pth'))

長所

  • PyTorch は Python のpicklePython の組み込みモジュールである pickle を使うと、辞書を容易にシリアライズ(pickle)、デシリアライズ(unpickle)、更新、復元できます。state_dict を用いると、モデルのパラメータ管理に柔軟性が生まれます。
  • モデルのパラメータに加えて、optimizer の state やハイパーパラメータなども、key-value 形式で state_dict に保存できます。state_dictこれらはモデルを再読み込みするときに簡単に参照できます。

短所

  • 読み込むにはstate_dict保存したパラメータを正しく読み込むには、同一のモデル定義が必要です。モデル定義が一致しない場合、読み込みは正しく動作しません。

注意点:

  • 必ず呼び出してくださいmodel.eval()推論時は、Dropout や BatchNorm などの層を評価モードに設定してください。
  • 次の方法でモデルを保存します.ptまたは.pth一貫性と互換性のための拡張子。

PyTorch モデル全体の保存と読み込み

PyTorch でモデルを保存・読み込みする際は、モデル全体を保存する方法も選べます。これはパラメータだけではなくモデルそのものを保存するやり方です。state_dictこの方法は、モデルのアーキテクチャとパラメータをまとめて一度に保存できるため、最小限のコードで素早く復元できます。
便��ではありますが、モデル全体の保存は一般的には推奨しません。Python のpickle保存時点のクラス定義やディレクトリ構成に強く依存するため、保存済みモデルファイルはその環境に厳密に結び付きます。コードをリファクタリングしたり、別プロジェクトでモデルを使いたい場合、この依存関係が互換性問題を引き起こすことがあります。
この方法で PyTorch モデルを保存と読み込みする手順を見ていきましょう。

PyTorch モデル全体の保存

torch.save(model, 'save/to/path/model.pt')

PyTorch モデル全体を読み込む方法

model = torch.load('load/from/path/model.pt')

長所

  • モデル全体の保存は記述量が最も少なく、素早く保存と読み込みを行う場合の手早い選択肢です。
  • モデル全体の保存と読み込み用 API は、直感的で分かりやすく使いやすいです。保存と読み込みの API はより直感的です。

短所

  • Python のpickle内部で pickle モジュールが使われるため、保存されたモデルは特定のクラス定義やディレクトリ構成に強く依存します。つまり、リファクタリングやファイルパスの変更があると、読み込み時に問題が発生する可能性があります。
  • フルモデル保存を別のプロジェクトで使うのは困難です。保存時と同一のディレクトリ構成とクラスのパスが必要になるためです。

注意点:

  • 必ず呼び出してくださいmodel.eval()推論の前には必ず model.eval() に切り替え、dropout や BatchNorm などの層が正しく動作するようにしてください。
  • 次の方法でモデルを保存します.ptまたは.pthモデルを保存する際は、拡張子は .pt または .pth を推奨し、一貫性と可読性を保ってください。

チェックポイントから PyTorch モデルを保存と読み込みする方法

多くの機械学習パイプラインでは、モデルのチェックポイントを一定間隔や特定の条件で保存することが不可欠です。こうしておけば、中断が発生しても最新または最良のチェックポイントから学習再開でき、訓練の継続性を確保できます。チェックポイントは、ファインチューニングや学習の各段階でのモデル評価にも有用です。
チェックポイントを保存する際は、モデルのみのstate_dictだけでは不十分です。オプティマイザの state_dict も保存してください。state_dict、最後のエポック番号、現在の損失、そして学習を途切れなく再開するために必要なその他の関連情報。

PyTorch モデルのチェックポイントを保存する方法

torch.save({'epoch': EPOCH,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS},
'save/to/path/model.pth')

PyTorch のモデルチェックポイントを読み込む

model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load('load/from/path/model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

注意点:

  • 学習を再開するときは、必ず次を呼び出してくださいmodel.train()モデルが学習モードに設定されていることを確実にするために。
  • チェックポイントを読み込んだ後に推論を行う場合は、model.eval() を呼び出してください。model.eval() モデルを評価モードに設定すること。


モデルのバージョン管理には W&B Artifacts を使用する

W&B Artifactsは、バージョン管理付きで PyTorch モデルを保存・読み込みする強力な方法を提供します。Artifact は「バージョン管理されたディレクトリ」と考えられ、モデルのバージョンやその他のアセットを構造化して保存・追跡できます。
Artifacts の詳細ここ

PyTorch モデルを W&B Artifacts として保存する方法

# Import
import wandb
# Save your model.
torch.save(model.state_dict(), 'save/to/path/model.pth')
# Save as artifact for version control.
run = wandb.init(project='your-project-name')
artifact = wandb.Artifact('model', type='model')
artifact.add_file('save/to/path/model.pth')
run.log_artifact(artifact)
run.finish()

W&B Artifacts として PyTorch モデルを読み込む方法

これで保存済みモデルがダウンロードされます。続いて、次のようにモデルを読み込めます。torch.load
import wandb
run = wandb.init()

artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')
artifact_dir = artifact.download()

run.finish()

図1ライブの Artifact ダッシュボードを確認してくださいここ

W&B Artifacts を使えば、PyTorch モデルの保存と読み込みがシームレスになり、バージョン管理が容易になります。これにより、プロジェクト内でのモデルの反復管理も簡単に行えます。

Weights & Biases を試す

Weights & Biases は、機械学習の実験を体系的に記録・管理するのに役立ちます。ハイパーパラメータや実行中に得られる指標をログに残し、結果を可視化・比較したうえで、同僚と素早く共有できます。
はじめに5分で始めるか、Replit で2つの簡単な実験を実行して、W&B が作業の整理にどう役立つかを確認しましょう。次の手順に従ってください。
手順
  1. 下の緑色の「Run」ボタンをクリックしてください(初回の実行時は、Replit がマシンを割り当てるために約30〜45秒かかります)
  2. ターミナルウィンドウの指示に従ってください
  3. ターミナルウィンドウのサイズを変更できます(右下)サイズを拡大



参考資料


Hongbo Miao
Hongbo Miao •  
I saw ONNX file format in the video "Integrate Weights & Biases with PyTorch" https://youtu.be/G7GH0SeNBMA?t=925 It would be great to add it. Thanks!
Reply
Arnav Das
Arnav Das •  
If for some reason Optimizer state wasn't saved, then the saved model state will be useless right ? or is there any way to that we can again retrain with only model state, loss and epoch information ?
3 replies