Skip to main content

Comment sauvegarder et charger des modèles dans PyTorch

Dans ce tutoriel, vous apprendrez à sauvegarder et à charger correctement votre modèle formé dans PyTorch.
Created on February 2|Last edited on June 9
Ceci est une traduction d'un article en anglais qui peut être trouvé ici.



Introduction

La formation de modèles est coûteuse et prend beaucoup de temps pour les cas d'utilisation pratiques. L'enregistrement du modèle formé est généralement la dernière étape de la plupart des flux de travail de formation de modèles. Elle est suivie par la réutilisation de ces modèles pré-entraînés. Il existe plusieurs façons de sauvegarder et de charger un modèle formé dans PyTorch.
Ce court rapport examine les moyens de sauvegarder et de charger un modèle formé dans l'écosystème PyTorch. Pour des instructions détaillées, consultez la documentation officielle de PyTorch.

Sections



Utiliser state_dict (recommandé)

Un state_dict est simplement un dictionnaire Python qui relie chaque couche à ses tenseurs de paramètres. Les paramètres apprenables d'un modèle (couches convolutionnelles, couches linéaires, etc.) et les tampons enregistrés (running_mean de batchnorm) ont des entrées dans state_dict Plus d'informations sur state_dict regardez ici.

Sauvegarder

torch.save(model.state_dict(), 'save/to/path/model.pth')

Charger

model = MyModelDefinition(args)
model.load_state_dict(torch.load('load/from/path/model.pth'))

Pour :

  • PyTorch s'appuie en interne sur le module pickle de Python. Les dictionnaires Python peuvent facilement être picklés, dépicklés, mis à jour et restaurés. Ainsi, la sauvegarde du modèle en utilisant state_dict offre plus de flexibilité.
  • Vous pouvez également sauvegarder l'état de l'optimiseur, les hyperparamètres, etc., sous forme de paires clé-valeur avec le state_dict du modèle. Une fois restaurés, vous pouvez y accéder comme à votre dictionnaire Python habituel. Nous verrons comment cela se passe dans la section suivante.

Contre :

  • Vous aurez besoin de la définition du modèle pour charger le state_dict.

Des problèmes :

  • Assurez-vous d'appeler model.eval() si vous voulez faire une inférence.
  • Enregistrez le modèle en utilisant l'extension .pt ou .pth.

Sauvegarder et charger le modèle entier

Vous pouvez également sauvegarder le modèle entier dans PyTorch et pas seulement le state_dict. Cependant, ce n'est pas une façon recommandée de sauvegarder le modèle.

Sauvegarder

torch.save(model, 'save/to/path/model.pt')

Charger

model = torch.load('load/from/path/model.pt')

Pour :

  • Le moyen le plus simple de sauvegarder l'ensemble du modèle avec le moins de code possible.
  • L'API de sauvegarde et de chargement est plus intuitive.

Contre :

  • Comme le module Pickle de Python est utilisé en interne, les données sérialisées sont liées aux classes spécifiques et la structure exacte du répertoire est utilisée lors de l'enregistrement du modèle. Pickle enregistre simplement un chemin vers le fichier contenant la classe spécifique. Ceci est utilisé au moment du chargement.
  • Comme vous pouvez l'imaginer, le code peut se casser après le refactoring car le modèle sauvegardé peut ne pas être lié au même chemin. L'utilisation d'un tel modèle dans un autre projet est également difficile car la structure du chemin doit être maintenue.
  • 

Des problèmes :

  • Assurez-vous d'appeler model.eval() si vous voulez faire de l'inférence.
  • Enregistrez le modèle en utilisant l'extension .pt ou .pth.

Sauvegarder et charger le modèle à partir d'un point de contrôle

La sauvegarde et le chargement d'un modèle général de point de contrôle pour l'inférence ou la reprise de la formation peuvent être utiles pour reprendre là où vous vous êtes arrêté. Le state_dict du modèle sauvegardé n'est pas suffisant dans le contexte du checkpoint. Vous devrez également sauvegarder le state_dict de l'optimiseur, ainsi que le dernier numéro d'époque, la perte, etc. Vous pourriez vouloir sauvegarder tout ce dont vous auriez besoin pour reprendre l'entraînement en utilisant un point de contrôle.

Sauvegarder

torch.save({
'epoch': EPOCH,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS,
}, 'save/to/path/model.pth')

Charger

model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load('load/from/path/model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

Des problèmes :

  • Pour l'utiliser pour la formation, appelez model.train().
  • Pour l'utiliser pour l'inférence, appelez model.eval().

Utiliser les Artefacts W&B pour le contrôle des versions du modèle


Essayez le cahier Colab, pour enregistrer le modèle en tant qu'artefacts de Weights & Biases et consommez les mêmes.
Les artefacts W&B peuvent être utilisés pour stocker et conserver la trace des ensembles de données, des modèles et des résultats d'évaluation dans les pipelines d'apprentissage automatique. Pensez à un artefact comme à un dossier de données versionné. Plus d'informations sur les artefacts ici.

Enregistrer en tant qu'artefacts

# Import
import wandb
# Save your model.
torch.save(model.state_dict(), 'save/to/path/model.pth')
# Save as artifact for version control.
run = wandb.init(project='your-project-name')
artifact = wandb.Artifact('model', type='model')
artifact.add_file('save/to/path/model.pth')
run.log_artifact(artifact)
run.join()

Charger en tant qu'artefacts

Cela permettra de télécharger le modèle sauvegardé. Vous pouvez ensuite charger le modèle en utilisant torch.load.
import wandb
run = wandb.init()

artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')
artifact_dir = artifact.download()

run.join()

Figure 1: Check out the live artifact dashboard here


Weights & Biases

Weights & Biases vous aide à garder la trace de vos expériences d'apprentissage automatique. Utilisez notre outil pour enregistrer les hyperparamètres et les métriques de sortie de vos exécutions, puis visualisez et comparez les résultats et partagez rapidement vos conclusions avec vos collègues.
Commencez en 5 minutes.

Lectures recommandées pour ceux qui s'intéressent à PyTorch


Léandre Ramos
Léandre Ramos •  
Cons Contre
1 reply
Iterate on AI agents and models faster. Try Weights & Biases today.