Skip to main content

Wie man Modelle in PyTorch speichert und lädt

In diesem Tutorium lernen Sie, wie Sie Ihr trainiertes Modell in PyTorch korrekt speichern und laden.
Created on February 2|Last edited on February 2
Dieser Artikel wurde aus dem Englischen übersetzt. Hier ist das Original.

Einführung

Das Trainieren von Modellen ist teuer und nimmt in praktischen Anwendungsfällen viel Zeit in Anspruch. Das Speichern des trainierten Modells ist in der Regel der letzte Schritt der meisten Modelltrainings-Workflows. Danach folgt die Wiederverwendung dieser trainierten Modelle. Es gibt mehrere Möglichkeiten, ein trainiertes Modell in PyTorch zu speichern und zu laden.
Dieser kurze Bericht befasst sich mit den Möglichkeiten zum Speichern und Laden eines trainierten Modells im PyTorch-Ökosystem. Detaillierte Anweisungen finden Sie in der offiziellen PyTorch-Dokumentation.

Sektionen




Verwenden Sie state_dict (empfohlen)

Ein state_dict ist einfach ein Python-Verzeichnis, das jede Schicht auf ihre Parameter-Tensoren abbildet. Die lernbaren Parameter eines Modells (Faltungsschichten, lineare Schichten usw.) und registrierte Puffer (batchnorm's running_mean) haben Einträge in state_dict. Mehr über state_dict finden Sie hier.

Speichern

torch.save(model.state_dict(), 'save/to/path/model.pth')

Laden

model = MyModelDefinition(args)
model.load_state_dict(torch.load('load/from/path/model.pth'))

Vorteile:

  • PyTorch stützt sich intern auf das Pickle-Modul von Python. Das Python-Verzeichnis kann leicht gepickelt, entpickelt, aktualisiert und wiederhergestellt werden. Das Speichern von Modellen mit state_dict bietet daher mehr Flexibilität.
  • Sie können auch den Optimierungsstatus, die Hyperparameter usw. als Schlüssel-Wert-Paare zusammen mit dem state_dict des Modells speichern. Nach der Wiederherstellung können Sie auf diese Paare wie auf ein gewöhnliches Python-Verzeichnis zugreifen. Wie das geht, werden wir im nächsten Abschnitt sehen.

Nachteile:

  • Sie benötigen die Modelldefinition, um das state_dict zu laden.

Tücken:

  • Stellen Sie sicher, dass Sie model.eval() aufrufen, wenn Sie eine Inferenz machen wollen.
  • Speichern Sie das Modell mit der Erweiterung .pt oder .pth.

Speichern und Laden des gesamten Modells

Sie können auch das gesamte Modell in PyTorch speichern und nicht nur das `state_dict. Dies ist jedoch keine empfohlene Methode zum Speichern des Modells.

Speichern

torch.save(model, 'save/to/path/model.pt')

Laden

model = torch.load('load/from/path/model.pt')


Vorteile:

  • Der einfachste Weg, das gesamte Modell mit dem geringsten Codeaufwand zu speichern.
  • Die API zum Speichern und Laden ist intuitiver.

Nachteile:

  • Da das Pickle-Modul von Python intern verwendet wird, sind die serialisierten Daten an die spezifischen Klassen gebunden und die genaue Verzeichnisstruktur wird beim Speichern des Modells verwendet. Pickle speichert einfach einen Pfad zu der Datei, welche die jeweilige Klasse enthält. Dies wird beim Laden verwendet.
  • Sie können sich vorstellen, dass der Code nach dem Refactoring nicht mehr funktioniert, da das gespeicherte Modell möglicherweise nicht mit demselben Pfad verknüpft ist. Die Verwendung eines solchen Modells in einem anderen Projekt ist ebenfalls schwierig, da die Pfadstruktur beibehalten werden muss.

Tücken:

  • Stellen Sie sicher, dass Sie model.eval() aufrufen, wenn Sie Inferenzen durchführen wollen.
  • Speichern Sie das Modell mit der Erweiterung .pt oder .pth.

Speichern und Laden des Modells von einem Checkpoint aus

Das Speichern und Laden eines allgemeinen Checkpoint-Modells zur Inferenz oder zur Wiederaufnahme des Trainings kann hilfreich sein, um dort weiterzumachen, wo man zuletzt aufgehört hat. Das Speichern des state_dict des Modells reicht im Zusammenhang mit dem Checkpoint nicht aus. Sie müssen auch das state_dict des Optimierers speichern, zusammen mit der letzten Epochenzahl, dem Verlust usw. Sie sollten alles speichern, was Sie benötigen, um das Training unter Verwendung eines Checkpoints fortzusetzen.

Speichern

torch.save({
'epoch': EPOCH,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS,
}, 'save/to/path/model.pth')

Laden

model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load('load/from/path/model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

Tücken:

  • Um dies für das Training zu verwenden, rufen Sie model.train() auf.
  • Um dies für die Inferenz zu verwenden, rufen Sie model.eval() auf.

Verwendung von W&B-Artefakten für die Modellversionskontrolle


Probieren Sie das Colab-Notizbuch aus, um das Modell als Weights & Biases Artefakte zu speichern und zu verbrauchen.
W&B Artefakte können zum Speichern und Nachverfolgen von Datensätzen, Modellen und Evaluierungsergebnissen in Pipelines für maschinelles Lernen verwendet werden. Stellen Sie sich ein Artefakt als einen versionierten Ordner mit Daten vor. Mehr über Artefakte erfahren Sie hier.

Als Artefakte speichern

# Import
import wandb
# Save your model.
torch.save(model.state_dict(), 'save/to/path/model.pth')
# Save as artifact for version control.
run = wandb.init(project='your-project-name')
artifact = wandb.Artifact('model', type='model')
artifact.add_file('save/to/path/model.pth')
run.log_artifact(artifact)
run.join()

Als Artefakte laden

Dadurch wird das gespeicherte Modell heruntergeladen. Sie können das Modell dann mit torch.load laden.
import wandb
run = wandb.init()

artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')
artifact_dir = artifact.download()

run.join()

Abbildung 1: Sehen Sie sich hier das Live-Artefakt-Dashboard an


Weights & Biases

Weights & Biases hilft Ihnen, den Überblick über Ihre Experimente zum maschinellen Lernen zu behalten. Nutzen Sie unser Tool, um Hyperparameter und Ausgabemetriken aus Ihren Durchläufen zu protokollieren, die Ergebnisse zu visualisieren und zu vergleichen und die Erkenntnisse schnell mit Ihren Kollegen zu teilen.
Starten Sie in 5 Minuten.

Empfohlene Lektüre für alle, die sich für PyTorch interessieren


Iterate on AI agents and models faster. Try Weights & Biases today.