PyTorch でモデルを保存・読み込む方法
本記事は、PyTorch でモデルを保存・読み込む方法を解説する機械学習チュートリアルであり、バージョン管理には Weights & Biases を用います。
Created on August 12|Last edited on August 12
Comment
現実的なユースケースでは、モデルの学習には多くの計算資源と時間がかかる。多くの機械学習ワークフローでは、学習済みモデルを保存することが最後のステップとなり、その後に推論で再利用する。
学習済みモデルを保存・読み込む方法はいくつかあります PyTorch。 このチュートリアルでは、PyTorch で学習済みモデルを保存・読み込むいくつかの方法を紹介します。詳細な手順については、こちらを参照してください。 公式の PyTorch ドキュメント。
目次
PyTorch のモデルを保存・読み込む方法 state_dict (推奨)PyTorch でモデルを保存する方法 state_dictPyTorch でモデルを読み込む方法 state_dict長所短所:注意点:PyTorch のモデル全体を保存・読み込むPyTorch のモデル全体を保存するPyTorch のモデル全体を読み込む長所:短所:注意点:チェックポイントから PyTorch モデルを保存・読み込みする方法モデルのバージョン管理には W&B Artifacts を使用するPyTorch モデルを W&B Artifacts として保存するPyTorch モデルを W&B Artifacts として読み込むWeights & Biases を試す参考資料
PyTorch のモデルを保存・読み込む方法 state_dict (推奨)
A state_dict は、各レイヤー名をそのパラメータテンソルへ対応付けるだけの Python の辞書です。モデルの学習可能パラメータ(畳み込み層、全結合層など)と、登録済みバッファ(BatchNorm の running_mean)にはエントリが含まれます state_dict
使用 state_dict は、次のような理由から推奨されるアプローチです。
- まず、 state_dict モデルの本質的なパラメータ(重みやバイアスなど)のみを保存するため、ファイルサイズを小さく保てて扱いやすい。次に、柔軟性が高い点が挙げられる。state_dict は Python の辞書であるため、モデルのパラメータだけでなくオプティマイザの状態やその他のメタデータも一緒に保存でき、学習の再開や微調整が容易になる。さらに、state_dict は Python の標準 pickle モジュールと互換性があり、異なる環境間でもシンプルに保存・読み込みできるため、モデルの再現性確保が簡単になる。本ガイドでは、PyTorch で state_dict を用いてモデルを保存・読み込みする手順を解説し、効果的なモデル管理のためのベストプラクティスを紹介する。
- さらに、state_dict は Python の標準 pickle モジュールと互換性があるため、環境が異なっても保存・読み込みをシンプルに行え、モデルの再ロード手順を容易にする。本ガイドでは、PyTorch で state_dict を用いてモデルを保存・読み込みする手順を示し、効果的なモデル管理のためのベストプラクティスを解説する。
PyTorch で次の方法を使ってモデルを保存・読み込みするためのコードは次のとおり。 state_dict:
PyTorch でモデルを保存する方法 state_dict
torch.save(model.state_dict(), 'save/to/path/model.pth')
PyTorch でモデルを読み込む方法 state_dict
model = MyModelDefinition(args)model.load_state_dict(torch.load('load/from/path/model.pth'))
長所
- PyTorch は Python に依存しており pickle モジュールで、Python の辞書を簡単にシリアライズ(pickle)およびデシリアライズ(unpickle)し、更新や復元ができる。state_dict を用いると、モデルのパラメータ管理に柔軟性が生まれる。
- モデルのパラメータに加えて、オプティマイザの状態やハイパーパラメータなども、キーと値のペアとして に保存できる。 state_dictこれらはモデルを再読み込みする際に容易に参照できる。
短所:
- 読み込むには state_dict、同一のモデル定義が必要である。これがないと、保存したパラメータを正しく読み込めない。
注意点:
- 必ず呼び出してください model.eval() 推論時は、ド���ップアウトやバッチ正規化などの層を評価モードに設定すること。
- 次の方法でモデルを保存します .pt または .pth 一貫性と互換性のための拡張子。
PyTorch のモデル全体を保存・読み込む
PyTorch でモデルを保存・読み込む際は、state_dict だけでなくモデル全体を保存することもできます。 state_dictこの方法ではモデルのアーキテクチャとパラメータを一括で保存でき、最小限のコードで迅速かつ容易に復元できる。
便利ではあるものの、モデル全体の保存は一般的には推奨されない。Python の pickle 保存時に使用した特定のクラス定義やディレクトリ構成に強く依存するため、保存されたモデルファイルはその環境に密接に結び付けられます。コードをリファクタリングした場合や、別のプロジェクトでモデルを利用したい場合に、互換性の問題が生じる可能性があります。
この方法で PyTorch のモデルを保存・読み込む手順を見ていきましょう。
PyTorch のモデル全体を保存する
torch.save(model, 'save/to/path/model.pt')
PyTorch のモデル全体を読み込む
model = torch.load('load/from/path/model.pt')
長所:
- モデル全体の保存は必要なコード量が最小で、保存・読み込みを素早く行う手軽な方法です。
- モデル全体を保存・読み込むための API はシンプルで使いやすい。保存と読み込みの API はより直感的である。
短所:
- Python の pickle このモジュールが内部で使用されるため、保存されたモデルは特定のクラス定義やディレクトリ構造に強く結び付けられます。つまり、リファクタリングやファイルパスの変更を行うと、読み込み時に問題が発生する可能性があります。
- 別プロジェクトでモデル全体の保存ファイルを使用するのは困難です。保存時と同一のディレクトリ構造とクラスパスでなければならないためです。
注意点:
- 必ず呼び出してください model.eval() ドロップアウトやバッチ正規化などの層が正しく動作するよう、推論の前に必ず適切な設定を行ってください。
- 次の方法でモデルを保存します .pt または .pth モデルを保存する際は、一貫性と可読性のために拡張子を統一してください。
チェックポイントから PyTorch モデルを保存・読み込みする方法
多くの機械学習パイプラインでは、モデルのチェックポイントを定期的に、または特定の条件に基づいて保存することが不可欠です。これにより、中断が発生しても最新または最良のチェックポイントから学習を再開でき、作業の継続性を確保できます。チェックポイントは、ファインチューニングや、学習の各段階でのモデル性能の評価にも有用です。
チェックポイントを保存する際、モデルのみに限定して保存する方法は state_dict だけでは不十分です。オプティマイザの state_dict も保存してください。 state_dict、最後のエポック番号、現在の損失値、そして学習を途切れなく再開するために必要なその他の関連情報。
PyTorch のモデルチェックポイントを保存する
torch.save({'epoch': EPOCH,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': LOSS},'save/to/path/model.pth')
PyTorch のモデルチェックポイントを読み込む
model = MyModelDefinition(args)optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)checkpoint = torch.load('load/from/path/model.pth')model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])epoch = checkpoint['epoch']loss = checkpoint['loss']
注意点:
- 学習を再開するときは、必ず呼び出してください model.train() モデルを学習モードに設定するために必ず実行してください。
- チェックポイントを読み込んだ後に推論を行う場合は、model.eval() を呼び出してください。 model.eval() モデルを評価モードに設定するために。
モデルのバージョン管理には W&B Artifacts を使用する
W&B Artifacts バージョン管理付きで PyTorch モデルを保存・読み込む強力な方法を提供する。Artifact はバージョン管理されたディレクトリと捉えられ、モデルのバージョンやその他のアセットを構造化して保存・追跡できる。
PyTorch モデルを W&B Artifacts として保存する
# Importimport wandb# Save your model.torch.save(model.state_dict(), 'save/to/path/model.pth')# Save as artifact for version control.run = wandb.init(project='your-project-name')artifact = wandb.Artifact('model', type='model')artifact.add_file('save/to/path/model.pth')run.log_artifact(artifact)run.finish()
PyTorch モデルを W&B Artifacts として読み込む
これで保存済みモデルがダウンロードされます。続いて、次の方法でモデルを読み込めます。 torch.load
import wandbrun = wandb.init()artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')artifact_dir = artifact.download()run.finish()

W&B の Artifacts を使えば、PyTorch モデルの保存と読み込みがシームレスになり、バージョン管理が容易になって、プロジェクト内でのモデル反復の管理も簡単に行えます。
Weights & Biases を試す
Weights & Biases は、機械学習実験の記録と管理を支援します。ハイパーパラメータや実行時のメトリクスを記録し、結果を可視化・比較して、洞察をすばやく同僚と共有できます。
手順
- 下の緑色の「Run」ボタンをクリックしてください(初回に「Run」をクリックすると、Replit がマシンを割り当てるまで約30〜45秒かかります)
- ターミナルウィンドウの指示に従ってください(右下のペイン)
- ターミナルウィンドウのサイズを変更できます(右下)を拡大表示できます
参考資料
How To Use GPU with PyTorch
A short tutorial on using GPUs for your deep learning models with PyTorch, from checking availability to visualizing usable.
PyTorch Dropout for regularization - tutorial
Learn how to regularize your PyTorch model with Dropout, complete with a code tutorial and interactive visualizations
Image Classification Using PyTorch Lightning and Weights & Biases
This article provides a practical introduction on how to use PyTorch Lightning to improve the readability and reproducibility of your PyTorch code.
Transfer Learning Using PyTorch Lightning
In this article, we have a brief introduction to transfer learning using PyTorch Lightning, building on the image classification example from a previous article.
Add a comment
I saw ONNX file format in the video "Integrate Weights & Biases with PyTorch" https://youtu.be/G7GH0SeNBMA?t=925
It would be great to add it. Thanks!
Reply
If for some reason Optimizer state wasn't saved, then the saved model state will be useless right ? or is there any way to that we can again retrain with only model state, loss and epoch information ?
3 replies
