Skip to main content

PyTorch でモデルを保存・読み込む方法

この記事は、PyTorch でモデルを保存・読み込みする方法を解説する機械学習チュートリアルであり、バージョン管理には Weights & Biases(W&B)を使用します。
Created on August 12|Last edited on August 12
実運用のユースケースでは、モデルの学習には多くの計算資源と時間がかかります。学習済みモデルを保存することは、多くの機械学習ワークフローにおける最後のステップであり、その後に推論で再利用します。
学習済みモデルを保存および読み込む方法は、PyTorch にはいくつかあります。PyTorch。このチュートリアルでは、PyTorch で学習済みモデルを保存・読み込みするいくつかの方法を紹介します。詳しい手順は、こちらを参照してください。公式 PyTorch ドキュメント

目次




PyTorch のモデルを���存して読み込む方法 state_dict (推奨)

A state_dict は、各レイヤー名からそのパラメータのテンソルへの対応を保持する単なる Python の辞書です。モデルの学習可能なパラメータ(畳み込み層、全結合層など)や、登録されたバッファ(Batch Normalization のバッファや Dropout に関連するものなど)を含みます。 running_mean)にエントリがある state_dict
使用 state_dict は、次のような理由から推奨される方法です。例えば、次の点が挙げられます。"
  • まず、 state_dict モデルの本質的なパラメータ(重みやバイアスなど)だけを保存するため、ファイルサイズを小さく保てて扱いやすくなります。さらに柔軟性も高く、state_dict は Python の辞書型であるため、モデルのパラメータに加えてオプティマイザの state_dict やその他のメタデータも一緒に保存でき、学習の再開やモデルの微調整/ファインチューニングが容易になります。加えて、state_dict は Python 標準の pickle モジュールと互換性があるため、異なる環境間でもシンプルに保存・読み込みができ、モデルの再現性確保に役立ちます。本ガイドでは、PyTorch で state_dict を用いてモデルを保存・読み込む手順と、効果的なモデル管理のためのベストプラクティスを解説します。
  • 第二に、state_dict は Python の辞書型なので、モデルのパラメータだけでなくオプティマイザの state_dict(状態)やその他のメタデータも保存でき、学習の再開が容易になります。モデルを微調整する
  • さらに、state_dict は Python の標準 pickle モジュールと互換性があるため、簡単に保存・読み込みができ、異なる環境間でモデルを再読み込みする手順をシンプルにします。本ガイドでは、PyTorch で state_dict を使用してモデルを保存・読み込みする手順を解説し、効果的なモデル管理のベストプラクティスを説明します。
詳しくは次をご覧ください state_dict テキストを送ってください。セクションごとに翻訳し、指摘されたスタイルガイドと用語統一ルールに従って改善します。 ここ
PyTorch でモデルを保存・読み込みするためのコードは次のとおりです state_dict

PyTorch でモデルを保存するには state_dict

torch.save(model.state_dict(), 'save/to/path/model.pth')

PyTorch でモデルを読み込むには state_dict

model = MyModelDefinition(args)
model.load_state_dict(torch.load('load/from/path/model.pth'))

長所

  • PyTorch は Python の pickle Python の辞書を簡単に pickle 化・unpickle 化し、更新・復元できるモジュールです。state_dict を使用すると、モデルのパラメータ管理に柔軟性を持たせることができます。
  • モデルのパラメータに加えて、オプティマイザのstate_dictやハイパーパラメータなどの追加要素も、キーと値のペアとして保存できます。 state_dictこれらはモデルを再読み込みするときに簡単に参照できます。

短所

  • 読み込むには state_dictそのため、まったく同じモデル定義が必要です。これがないと、保存したパラメータ(state_dict)を正しく読み込むことはできません。

注意点

  • 必ず呼び出してください model.eval() 推論時には、Dropout や Batch Normalization といったレイヤーを評価モードに設定してください。
  • モデルを保存する .pt または .pth 一貫性と互換性のために拡張子を統一してください。

PyTorch のモデル全体を保存・読み込みする方法

PyTorch でモデルを保存・読み込みする際には、パラメータだけではなくモデル全体を保存することもできます。 state_dictこの方法では、モデルのアーキテクチャとパラメータをまとめて一度に保存でき、最小限のコードで素早く復元できます。
便利ではありますが、モデル全体の保存は一般的には推奨されません。Python の pickle 保存時点のクラス定義やディレクトリ構成に強く依存するため、保存したモデルファイルはその環境に密結合になります。この依存性により、コードをリファクタリングした場合や別のプロジェクトでモデルを使いたい場合に、互換性の問題が発生する可能性があります。
この方法で PyTorch のモデルを保存・読み込みする手���を見ていきましょう。

PyTorch のモデル全体を保存する

torch.save(model, 'save/to/path/model.pt')

PyTorch モデル全体を読み込む

model = torch.load('load/from/path/model.pt')

長所

  • モデル全体の保存は必要なコード量が最も少なく、保存と読み込みを素早く行うための手軽な選択肢です。
  • モデル全体を保存・読み込みするための API は、シンプルで使いやすいです。保存と読み込みの API は、より直感的です。

短所

  • Python の pickle module が内部で使用されるため、保存されたモデルは特定のクラス定義やディレクトリ構造に強く結び付けられます。つまり、リファクタリングやファイルパスの変更があると、読み込み時に問題が発生する可能性があります。
  • 別プロジェクトでモデル全体の保存ファイルを使用するのは難しく、保存時と同一のディレクトリ構成やクラスのパスが必要になります。

注意点

  • 必ず呼び出してください model.eval() 推論を行う前に、dropout や Batch Normalization などの層が正しく動作するようにするためです。
  • モデルを保存する .pt または .pth モデルを保存する際は、統一性と可読性を保つために .pt または .pth の拡張子を使用してください。

チェックポイントから PyTorch モデルを保存・読み込みする

多くの機械学習パイプラインでは、モデルのチェックポイントを一定間隔や特定の条件に基づいて保存することが不可欠です。これにより、中断が発生しても最新または最良のチェックポイントから学習を再開でき、作業を継続できます。チェックポイントは、微調整や学習の各段階でのモデル性能評価にも有用です。
チェックポイントを保存する際は、モデルのみの state_dict だけでなく、オプティマイザの state_dict やエポック、損失、ハイパーパラメータなどのメタデータも一緒に保存してください。 state_dict だけでは不十分です。オプティマイザの state_dict(状態)も保存してください。 state_dict、最後のエポック番号、現在の損失値、および学習を途切れなく再開するために必要なその他の関連情報。

PyTorch のモデルチェックポイントを保存する

torch.save({'epoch': EPOCH,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS},
'save/to/path/model.pth')

PyTorch のモデルチェックポイントを読み込む

model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load('load/from/path/model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

注意点

  • 学習を再開するときは、必ず model.train() を呼び出してください。 model.train() モデルを学習モードに設定することを確実にするために。
  • チェックポイントを読み込んだ後に推論を行う場合は、次を呼び出してください model.eval() モデルを評価モードに設定するために。


モデルのバージョン管理には W&B Artifacts を使用する

W&B Artifactsバージョン管理付きで PyTorch モデルを保存・読み込みできる強力な方法を提供します。Artifact はバージョン管理されたディレクトリと考えることができ、モデルの各バージョンやその他のアセットを構造化して保存・追跡できます。
Artifacts の詳細ここ

PyTorch のモデルを W&B Artifacts として保存する

# Import
import wandb
# Save your model.
torch.save(model.state_dict(), 'save/to/path/model.pth')
# Save as artifact for version control.
run = wandb.init(project='your-project-name')
artifact = wandb.Artifact('model', type='model')
artifact.add_file('save/to/path/model.pth')
run.log_artifact(artifact)
run.finish()

W&B Artifacts として PyTorch モデルを読み込む

これで保存済みモデルがダウンロードされます。続いて、次の方法でモデルを読み込めます。 torch.load
import wandb
run = wandb.init()

artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')
artifact_dir = artifact.download()

run.finish()

図1ライブの Artifacts ダッシュボードを確認してくださいここ

W&B Artifacts を使えば、PyTorch モデルの保存と読み込みがシームレスになり、バージョン管理が容易になります。これにより、プロジェクト内でモデルの反復版を簡単に管理できます。

Weights & Biases を試す

Weights & Biases は、機械学習実験の記録と管理を支援します。ハイパーパラメータや実行時の出力メトリクスをログに記録し、結果を可視化・比較して、得られた知見を同僚とすばやく共有できます。
はじめに5分で、またはReplitで手早く2つの実験を実行して、W&Bが作業の整理にどう役立つかを確認しましょう。以下の手順に従ってください。
手順
  1. 下の緑色の「Run」ボタンをクリックしてください初回に「Run」をクリックすると、Replit がマシンを割り当てるまでに約 30〜45 秒かかります。
  2. ターミナルウィンドウの指示に従ってください本文では、PyTorch のモデルを保存・読み込みする実践的な手法を、Weights & Biases(W&B)でのバージョン管理を交えながら解説します。推奨は state_dict(学習可能なパラメータと登録済みバッファを含む Python の辞書)での保存です。ファイルが小さく、柔軟で(オプティマイザの state_dict やハイパーパラメータ、メタデータも含められる)、Python の pickle により環境間で可搬性が高いからです。state_dict の保存・読み込み方法を示し、再読み込みには同一のモデルクラス定義が必要である点、推論前には model.eval()、学習再開や微調整(どちらかに統一して表記する場合は「微調整」)前には model.train() を必ず呼び出す点を確認します。 また、モデル全体(アーキテクチャ+重み)を単一オブジェクトとして保存する方法も説明します。これはコードが最も簡単ですが、元のクラス定義やディレクトリ構造に強く依存するため、プロジェクト間やリファクタリング後に脆くなることから非推奨です。 さらに、堅牢な学習ワークフローのためのチェックポイントも扱います。モデルとオプティマイザの state_dict に加え、エポック、損失、その他のメタデータを定期的に保存しておくことで、学習の再開や後での微調整をシームレスに行えます。 最後に、W&B Artifacts を用いたモデルのバージョン管理を紹介します。アーティファクトとしてモデルファイルを記録・バージョン管理・取得でき、実験やチームをまたいだ再現性と追跡可能性の高いモデル管理が可能です。重要な注意点として、ファイル拡張子は .pt または .pth を一貫して使用すること、読み込み後には用途に応じて model.eval() と model.train() を正しく設定することを挙げます。
  3. ターミナルウィンドウのサイズを変更できます(右下)をクリックして拡大表示



おすすめの読み物


Hongbo Miao
Hongbo Miao •  
I saw ONNX file format in the video "Integrate Weights & Biases with PyTorch" https://youtu.be/G7GH0SeNBMA?t=925 It would be great to add it. Thanks!
Reply
Arnav Das
Arnav Das •  
If for some reason Optimizer state wasn't saved, then the saved model state will be useless right ? or is there any way to that we can again retrain with only model state, loss and epoch information ?
3 replies