Skip to main content

PyTorchでモデルを保存およびロードする方法

このチュートリアルでは、トレーンしたモデルをPyTorchに正しく保存して読み込む方法を学習します。
Created on January 30|Last edited on January 30
このレポートは、Ayush Thakurによる「How to Save and Load Models in PyTorch」の翻訳です。

初めに

モデルのトレーニングは費用と時間がかかり、実際のユースケースには多くの時間がかかります。 トレーニング済みモデルの保存は、通常、ほとんどのモデルトレーニングワークフローの最後のステップです。 その後、これらの事前トレーニング済みモデルを再利用します。 PyTorchでトレーニング済みモデルを保存およびロードする方法はいくつかあります。
この短いレポートでは、PyTorchエコシステムでトレーニング済みモデルを保存およびロードする方法について説明します。 詳細な手順については、PyTorchの公式ドキュメントをご覧ください


state_dictを使用する(推奨)

state_dict は、各レイヤーをそのパラメーターテンソルにマップする単純なPythonディクショナリです。 モデルの学習可能なパラメーター(たたみ込み層、線形層など)と登録されたバッファー(batchnormのrunning_mean)には、state_dict にエントリがあります。 state_dict の詳細については、こちらをご覧ください。

保存

torch.save(model.state_dict(), 'save/to/path/model.pth')

ロード

model = MyModelDefinition(args)
model.load_state_dict(torch.load('load/from/path/model.pth'))

長所:

  • PyTorchは、Pythonのpickleモジュールに内部的に依存しています。 Pythonディクショナリは、簡単にピクルス、アンピクル、更新、および復元できます。 したがって、state_dict を使用してモデルを保存すると、柔軟性が高まります。
  • オプティマイザーの状態、ハイパーパラメーターなどを、モデルのstate_dictとともにキーと値のペアとして保存することもできます。 復元すると、通常のPython辞書と同じようにアクセスできます。 それがどのように行われるかについては、後のセクションで説明します。

短所:

  • state_dict をロードするには、モデル定義が必要です。

要注意:

  • 推論を行う場合は、必ずmodel.eval() を呼び出してください。
  • .pt または.pth 拡張子を使用してモデルを保存します。


モデル全体を保存してロードする

`state_dictだけでなく、モデル全体をPyTorchに保存することもできます。 ただし、これはモデルを保存するための推奨される方法ではありません。

保存

torch.save(model, 'save/to/path/model.pt')

読み込み

model = torch.load('load/from/path/model.pt')

長所:

  • 最小限のコードでモデル全体を保存する最も簡単な方法。
  • APIの保存と読み込みはより直感的です。

短所:

  • Pythonのpickleモジュールは内部で使用されるため、シリアル化されたデータは特定のクラスにバインドされ、モデルの保存時に正確なディレクトリ構造が使用されます。 Pickleは、特定のクラスを含むファイルへのパスを保存するだけです。 これは、読み込み時に使用されます。
  • ご想像のとおり、保存されたモデルが同じパスにリンクしていない可能性があるため、リファクタリング後にコードが破損する可能性があります。 パス構造を維持する必要があるため、このようなモデルを別のプロジェクトで使用することも困難です。

要注意:

  • 推論を行う場合は、必ずmodel.eval() を呼び出してください。
  • .ptまたは.pth 拡張子を使用してモデルを保存します

チェックポイントからモデルを保存してロードする

推論またはトレーニングの再開のために一般的なチェックポイントモデルを保存およびロードすると、最後に中断したところから再開するのに役立ちます。 チェックポイントのコンテキストでは、モデルのstate_dictを保存するだけでは不十分です。 また、オプティマイザーのstate_dictを、最後のエポック番号、損失などとともに保存する必要があります。チェックポイントを使用してトレーニングを再開するために必要なすべてのものを保存することをお勧めします。

保存

torch.save({
'epoch': EPOCH,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS,
}, 'save/to/path/model.pth')

読み込み

model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load('load/from/path/model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

注意:

  • トレーニングに使用するには、model.train() を呼び出します。
  • 推論に使用するには、model.train() を呼び出します。


モデルバージョン管理にW&B Artifacts使用する


モデルをWeights & Biases Artifactsとして保存し、同様に消費するためにColabノートブックを試してみましょう。
W&Bアーティファクトを使用して、機械学習パイプライン全体でデータセット、モデル、評価結果を保存および追跡できます。 アーティファクトは、バージョン管理されたデータのフォルダーと考えてください。 アーティファクトの詳細については、こちらをご覧ください。


アーティファクトとして保存する

# Import
import wandb
# Save your model.
torch.save(model.state_dict(), 'save/to/path/model.pth')
# Save as artifact for version control.
run = wandb.init(project='your-project-name')
artifact = wandb.Artifact('model', type='model')
artifact.add_file('save/to/path/model.pth')
run.log_artifact(artifact)
run.join()

アーティファクトとして読み込みする

保存したモデルがダウンロードされます。 次に、torch.load を使用してモデルを読み込みます。
import wandb
run = wandb.init()

artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')
artifact_dir = artifact.download()

run.join()

図1:ライブアーティファクトダッシュボードをここで確認してください

Weights & Biases

Weights&Biasesは、機械学習の実験を追跡するのに役立ちます。 当社のツールで、ハイパーパラメータをログに記録・メトリックを出力して、結果を視覚化して比較すれば、結果を同僚とすばやく共有することができます。
5分で開始します

PyTorchに興味のある方におすすめ


Iterate on AI agents and models faster. Try Weights & Biases today.