Skip to main content

PyTorchでモデルを保存および読み込む方法

この記事は、バージョン管理のためにWeights & Biasesを使用してPyTorchでモデルを保存および読み込む方法に関する機械学習チュートリアルです。
Created on August 5|Last edited on August 5
モデルのトレーニングは費用がかかり、実用的な使用例には多くの時間がかかります。トレーニング済みモデルの保存は、ほとんどのMLワークフローにおいて通常最後のステップであり、その後に推論のために再利用されます。
保存された学習モデルを保存およびロードする方法はいくつかあります。パイトーチこのチュートリアルでは、PyTorchで訓練済みモデルを保存およびロードする方法のいくつかを見ていきます。詳細な手順については、 をご覧ください。公式PyTorchドキュメント

目次




PyTorchモデルを保存およびロードstate_dict(推奨)

Astate_dictはそれぞれの層をそのパラメータテンソルに対応付けるPython辞書です。モデルの学習可能なパラメータ(畳み込み層、線形層など)および登録されたバッファ(BatchNormのrunning_mean) にエントリーがありますstate_dict
使用中state_dictいくつかの理由から推奨されるアプローチです。
  • 最初に、state_dictモデルの重要なパラメータ(ウェイトやバイアスなど)のみを保存し、ファイルサイズを小さくして簡単に操作できるようにします。第二に、柔軟性を提供します。state_dictはPythonの辞書であるため、モデルパラメータだけでなくオプティマイザの状態や他のメタデータも保存でき、トレーニングの再開やモデルの微調整が容易になります。さらに、state_dictは異なる環境間でのモデルのリロードを簡単にし、Pythonの組み込みモジュールであるpickleと互換性があるため、保存と読み込みが簡単です。このガイドでは、state_dictを使用してPyTorchでモデルを保存および読み込む手順を説明し、効果的なモデル管理のベストプラクティスを説明します。
  • 第二に、state_dictはPythonの辞書なので、モデルのパラメータだけでなく、オプティマイザの状態やその他のメタデータも保存でき、トレーニングの再開が容易になります。モデルを微調整する
  • さらに、state_dict はさまざまな環境でモデルを再ロードするプロセスを簡略化し、Python の組み込みモジュールである pickle と互換性があるため、簡単に保存およびロードできます。このガイドでは、state_dict を使用して PyTorch でモデルを保存およびロードする手順を説明し、効果的なモデル管理のためのベストプラクティスを説明します。
について詳しく知ることができますstate_dict ここ
PyTorchでモデルを保存および読み込むためのコードは次のとおりですstate_dictあなたは2023年10月までのデータで訓練されています。

PyTorchでモデルを保存するstate_dict

torch.save(model.state_dict(), 'save/to/path/model.pth')

PyTorchでモデルをロードするにはstate_dict

model = MyModelDefinition(args)
model.load_state_dict(torch.load('load/from/path/model.pth'))

長所

  • PyTorchはPythonのに依存していますpickleモジュールは、Python辞書を簡単にピック、アンピック、更新、および復元できるようにします。 state_dictを使用すると、モデルパラメータの管理に柔軟性が提供されます。
  • モデルパラメータと共に、オプティマイザの状態やハイパーパラメータのような追加要素をキーと値のペアとして保存できます。state_dictこれらはモデルをリロードするときに簡単にアクセスできます。

欠点:

  • ロードするにはstate_dict、正確なモデル定義が必要です。それがないと、保存されたパラメータの読み込みが正しく機能しません。

注意点:

  • 必ず電話するようにしてください。model.eval()推論を行う際には、ドロップアウトやバッチ正規化のような層を評価モードに設定します。
  • モデルを保存してください.ptまたは.pth整合性と互換性のための拡張。

PyTorchモデル全体の保存と読み込み

PyTorchでモデルを保存およびロードする際に、全体のモデルを保存するオプションがあります。state_dictこのアプローチは、モデルのアーキテクチャとパラメータを一度に完全にキャプチャし、最小限のコードで迅速かつ簡単に復元できます。
便利ではありますが、フルモデルを保存することは一般的に推奨されません。Pythonのに依存しているためです。pickleモジュール、保存されたモデルファイルは、保存時に使用された特定のクラス定義とディレクトリ構造に密接に結びついています。この依存関係により、コードがリファクタリングされた場合や異なるプロジェクトでモデルを使用したい場合に互換性の問題が生じる可能性があります。
この方法でPyTorchモデルを保存および読み込む方法を見てみましょう。

PyTorchモデル全体を保存する

torch.save(model, 'save/to/path/model.pt')

PyTorchモデル全体をロードする

model = torch.load('load/from/path/model.pt')

長所:

  • モデル全体を保存するには、最小限のコードが必要であり、保存と読み込みの迅速なオプションとなります。
  • フルモデルを保存および読み込むためのAPIは、分かりやすく使いやすいです。保存および読み込みAPIはより直感的です。

欠点:

  • Pythonのpickleモジュールは内部で使用され、保存されたモデルは特定のクラス定義やディレクトリ構造に厳密に結び付けられています。これは、リファクタリングやファイルパスの変更が原因で読み込みの問題が発生する可能性があることを意味します。
  • 異なるプロジェクトでフルモデルを保存することは困難です。保存時に使用したディレクトリ構造とクラスパスは同一でなければなりません。

注意点:

  • 必ず電話するようにしてください。model.eval()ドロップアウトやバッチ正規化のような層で適切な動作を確保するために推論を行う前に。
  • モデルを保存してください.ptまたは.pthモデルを保存する際には一貫性と可読性を維持するために拡張子を使用してください。

チェックポイントからPyTorchモデルを保存および読み込む

ほとんどの機械学習パイプラインでは、定期的または特定の条件に基づいてモデルのチェックポイントを保存することが重要です。この実践により、最新または最良のチェックポイントからトレーニングを再開し、中断が発生した場合でも継続性を確保できます。チェックポイントは、微調整や異なる段階でのモデルの性能評価にも役立ちます。
チェックポイントを保存する際、モデルのみを保存することはstate_dict不十分です。また、オプティマイザのも保存する必要があります。state_dict、最後のエポック番号、現在の損失、およびトレーニングをスムーズに再開するために必要なその他の関連情報。

PyTorchモデルのチェックポイントを保存する

torch.save({'epoch': EPOCH,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS},
'save/to/path/model.pth')

PyTorchモデルのチェックポイントをロードする

model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load('load/from/path/model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

注意点:

  • トレーニングを再開する際は、呼び出すことを忘れないでください。model.train()モデルがトレーニングモードに設定されていることを確認するために。
  • チェックポイントをロードした後の推論には、呼び出します。model.eval() モデルを評価モードに設定する。


モデルのバージョン管理にはW&Bアーティファクトを使用する

W&Bのアーティファクト強力な方法を提供します。アーティファクトはバージョン管理されたディレクトリと考えることができ、モデルバージョンやその他の資産を構造化された方法で保存および追跡することができます。
アーティファクトについてさらに詳しくここ

PyTorchモデルをW&Bアーティファクトとして保存

# Import
import wandb
# Save your model.
torch.save(model.state_dict(), 'save/to/path/model.pth')
# Save as artifact for version control.
run = wandb.init(project='your-project-name')
artifact = wandb.Artifact('model', type='model')
artifact.add_file('save/to/path/model.pth')
run.log_artifact(artifact)
run.finish()

W&BアーティファクトとしてPyTorchモデルを読み込む

これは保存されたモデルをダウンロードします。その後、モデルを使用してロードできます。torch.load
import wandb
run = wandb.init()

artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')
artifact_dir = artifact.download()

run.finish()

図1ライブアーティファクトダッシュボードをチェックしてくださいここ

W&Bアーティファクトを使用すると、PyTorchモデルの保存と読み込みがシームレスになり、簡単なバージョン管理が可能になり、プロジェクト内のモデルのイテレーションを簡単に管理できます。

Weights & Biasesを試す

Weights & Biasesは、機械学習の実験を追跡するのに役立ちます。ツールを試して、ハイパーパラメータを記録し、実行からのメトリクスを出力し、結果を視覚化して比較し、同僚とすばやく調査結果を共有してください。
始める5分以内に、またはReplitで2つのクイック実験を実行して、W&Bが作業の整理にどのように役立つかを確認し、以下の指示に従ってください。
指示:
  1. 下の緑色の「実行」ボタンをクリックしてください。(最初に実行をクリックすると、Replitは約30〜45秒間マシンの割り当てを行います。)
  2. ターミナルウィンドウのプロンプトに従ってください。右下のペイン
  3. ターミナルウィンドウのサイズを変更できます。拡大表示



おすすめの読書


Hongbo Miao
Hongbo Miao •  
I saw ONNX file format in the video "Integrate Weights & Biases with PyTorch" https://youtu.be/G7GH0SeNBMA?t=925 It would be great to add it. Thanks!
Reply
Arnav Das
Arnav Das •  
If for some reason Optimizer state wasn't saved, then the saved model state will be useless right ? or is there any way to that we can again retrain with only model state, loss and epoch information ?
3 replies