PyTorchでモデルを保存およびロードする方法
このチュートリアルでは、トレーンしたモデルをPyTorchに正しく保存して読み込む方法を学習します。
Created on January 30|Last edited on January 30
Comment
初めに
モデルのトレーニングは費用と時間がかかり、実際のユースケースには多くの時間がかかります。 トレーニング済みモデルの保存は、通常、ほとんどのモデルトレーニングワークフローの最後のステップです。 その後、これらの事前トレーニング済みモデルを再利用します。 PyTorchでトレーニング済みモデルを保存およびロードする方法はいくつかあります。
初めにstate_dictを使用する(推奨)モデル全体を保存してロードするチェックポイントからモデルを保存してロードするモデルバージョン管理にW&B Artifacts使用するWeights & BiasesPyTorchに興味のある方におすすめ
state_dictを使用する(推奨)
state_dict は、各レイヤーをそのパラメーターテンソルにマップする単純なPythonディクショナリです。 モデルの学習可能なパラメーター(たたみ込み層、線形層など)と登録されたバッファー(batchnormのrunning_mean)には、state_dict にエントリがあります。 state_dict の詳細については、こちらをご覧ください。
保存
torch.save(model.state_dict(), 'save/to/path/model.pth')
ロード
model = MyModelDefinition(args)model.load_state_dict(torch.load('load/from/path/model.pth'))
長所:
- PyTorchは、Pythonのpickleモジュールに内部的に依存しています。 Pythonディクショナリは、簡単にピクルス、アンピクル、更新、および復元できます。 したがって、state_dict を使用してモデルを保存すると、柔軟性が高まります。
- オプティマイザーの状態、ハイパーパラメーターなどを、モデルのstate_dictとともにキーと値のペアとして保存することもできます。 復元すると、通常のPython辞書と同じようにアクセスできます。 それがどのように行われるかについては、後のセクションで説明します。
短所:
- state_dict をロードするには、モデル定義が必要です。
要注意:
- 推論を行う場合は、必ずmodel.eval() を呼び出してください。
- .pt または.pth 拡張子を使用してモデルを保存します。
モデル全体を保存してロードする
`state_dictだけでなく、モデル全体をPyTorchに保存することもできます。 ただし、これはモデルを保存するための推奨される方法ではありません。
保存
torch.save(model, 'save/to/path/model.pt')
読み込み
model = torch.load('load/from/path/model.pt')
長所:
- 最小限のコードでモデル全体を保存する最も簡単な方法。
- APIの保存と読み込みはより直感的です。
短所:
- Pythonのpickleモジュールは内部で使用されるため、シリアル化されたデータは特定のクラスにバインドされ、モデルの保存時に正確なディレクトリ構造が使用されます。 Pickleは、特定のクラスを含むファイルへのパスを保存するだけです。 これは、読み込み時に使用されます。
- ご想像のとおり、保存されたモデルが同じパスにリンクしていない可能性があるため、リファクタリング後にコードが破損する可能性があります。 パス構造を維持する必要があるため、このようなモデルを別のプロジェクトで使用することも困難です。
要注意:
- 推論を行う場合は、必ずmodel.eval() を呼び出してください。
- .ptまたは.pth 拡張子を使用してモデルを保存します
チェックポイントからモデルを保存してロードする
推論またはトレーニングの再開のために一般的なチェックポイントモデルを保存およびロードすると、最後に中断したところから再開するのに役立ちます。 チェックポイントのコンテキストでは、モデルのstate_dictを保存するだけでは不十分です。 また、オプティマイザーのstate_dictを、最後のエポック番号、損失などとともに保存する必要があります。チェックポイントを使用してトレーニングを再開するために必要なすべてのものを保存することをお勧めします。
保存
torch.save({'epoch': EPOCH,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': LOSS,}, 'save/to/path/model.pth')
読み込み
model = MyModelDefinition(args)optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)checkpoint = torch.load('load/from/path/model.pth')model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])epoch = checkpoint['epoch']loss = checkpoint['loss']
注意:
- トレーニングに使用するには、model.train() を呼び出します。
- 推論に使用するには、model.train() を呼び出します。
モデルバージョン管理にW&B Artifacts使用する
W&Bアーティファクトを使用して、機械学習パイプライン全体でデータセット、モデル、評価結果を保存および追跡できます。 アーティファクトは、バージョン管理されたデータのフォルダーと考えてください。 アーティファクトの詳細については、こちらをご覧ください。
アーティファクトとして保存する
# Importimport wandb# Save your model.torch.save(model.state_dict(), 'save/to/path/model.pth')# Save as artifact for version control.run = wandb.init(project='your-project-name')artifact = wandb.Artifact('model', type='model')artifact.add_file('save/to/path/model.pth')run.log_artifact(artifact)run.join()
アーティファクトとして読み込みする
保存したモデルがダウンロードされます。 次に、torch.load を使用してモデルを読み込みます。
import wandbrun = wandb.init()artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')artifact_dir = artifact.download()run.join()
Weights & Biases
Weights&Biasesは、機械学習の実験を追跡するのに役立ちます。 当社のツールで、ハイパーパラメータをログに記録・メトリックを出力して、結果を視覚化して比較すれば、結果を同僚とすばやく共有することができます。
PyTorchに興味のある方におすすめ
How To Use GPU with PyTorch
A short tutorial on using GPUs for your deep learning models with PyTorch, from checking availability to visualizing usable.
PyTorch Dropout for regularization - tutorial
Learn how to regularize your PyTorch model with Dropout, complete with a code tutorial and interactive visualizations
Image Classification Using PyTorch Lightning and Weights & Biases
This article provides a practical introduction on how to use PyTorch Lightning to improve the readability and reproducibility of your PyTorch code.
Transfer Learning Using PyTorch Lightning
In this article, we have a brief introduction to transfer learning using PyTorch Lightning, building on the image classification example from a previous article.
Add a comment
Iterate on AI agents and models faster. Try Weights & Biases today.