Comment sauvegarder et charger des modèles dans PyTorch
Dans ce tutoriel, vous apprendrez à sauvegarder et à charger correctement votre modèle formé dans PyTorch.
Created on February 2|Last edited on June 9
Comment
Introduction
La formation de modèles est coûteuse et prend beaucoup de temps pour les cas d'utilisation pratiques. L'enregistrement du modèle formé est généralement la dernière étape de la plupart des flux de travail de formation de modèles. Elle est suivie par la réutilisation de ces modèles pré-entraînés. Il existe plusieurs façons de sauvegarder et de charger un modèle formé dans PyTorch.
Ce court rapport examine les moyens de sauvegarder et de charger un modèle formé dans l'écosystème PyTorch. Pour des instructions détaillées, consultez la documentation officielle de PyTorch.
Sections
IntroductionSectionsUtiliser state_dict (recommandé)Sauvegarder et charger le modèle entierSauvegarder et charger le modèle à partir d'un point de contrôleUtiliser les Artefacts W&B pour le contrôle des versions du modèleWeights & BiasesLectures recommandées pour ceux qui s'intéressent à PyTorch
Utiliser state_dict (recommandé)
Un state_dict est simplement un dictionnaire Python qui relie chaque couche à ses tenseurs de paramètres. Les paramètres apprenables d'un modèle (couches convolutionnelles, couches linéaires, etc.) et les tampons enregistrés (running_mean de batchnorm) ont des entrées dans state_dict Plus d'informations sur state_dict regardez ici.
Sauvegarder
torch.save(model.state_dict(), 'save/to/path/model.pth')
Charger
model = MyModelDefinition(args)model.load_state_dict(torch.load('load/from/path/model.pth'))
Pour :
- PyTorch s'appuie en interne sur le module pickle de Python. Les dictionnaires Python peuvent facilement être picklés, dépicklés, mis à jour et restaurés. Ainsi, la sauvegarde du modèle en utilisant state_dict offre plus de flexibilité.
- Vous pouvez également sauvegarder l'état de l'optimiseur, les hyperparamètres, etc., sous forme de paires clé-valeur avec le state_dict du modèle. Une fois restaurés, vous pouvez y accéder comme à votre dictionnaire Python habituel. Nous verrons comment cela se passe dans la section suivante.
Contre :
- Vous aurez besoin de la définition du modèle pour charger le state_dict.
Des problèmes :
- Assurez-vous d'appeler model.eval() si vous voulez faire une inférence.
- Enregistrez le modèle en utilisant l'extension .pt ou .pth.
Sauvegarder et charger le modèle entier
Vous pouvez également sauvegarder le modèle entier dans PyTorch et pas seulement le state_dict. Cependant, ce n'est pas une façon recommandée de sauvegarder le modèle.
Sauvegarder
torch.save(model, 'save/to/path/model.pt')
Charger
model = torch.load('load/from/path/model.pt')
Pour :
- Le moyen le plus simple de sauvegarder l'ensemble du modèle avec le moins de code possible.
- L'API de sauvegarde et de chargement est plus intuitive.
Contre :
- Comme le module Pickle de Python est utilisé en interne, les données sérialisées sont liées aux classes spécifiques et la structure exacte du répertoire est utilisée lors de l'enregistrement du modèle. Pickle enregistre simplement un chemin vers le fichier contenant la classe spécifique. Ceci est utilisé au moment du chargement.
- Comme vous pouvez l'imaginer, le code peut se casser après le refactoring car le modèle sauvegardé peut ne pas être lié au même chemin. L'utilisation d'un tel modèle dans un autre projet est également difficile car la structure du chemin doit être maintenue.
-
Des problèmes :
- Assurez-vous d'appeler model.eval() si vous voulez faire de l'inférence.
- Enregistrez le modèle en utilisant l'extension .pt ou .pth.
Sauvegarder et charger le modèle à partir d'un point de contrôle
La sauvegarde et le chargement d'un modèle général de point de contrôle pour l'inférence ou la reprise de la formation peuvent être utiles pour reprendre là où vous vous êtes arrêté. Le state_dict du modèle sauvegardé n'est pas suffisant dans le contexte du checkpoint. Vous devrez également sauvegarder le state_dict de l'optimiseur, ainsi que le dernier numéro d'époque, la perte, etc. Vous pourriez vouloir sauvegarder tout ce dont vous auriez besoin pour reprendre l'entraînement en utilisant un point de contrôle.
Sauvegarder
torch.save({'epoch': EPOCH,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': LOSS,}, 'save/to/path/model.pth')
Charger
model = MyModelDefinition(args)optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)checkpoint = torch.load('load/from/path/model.pth')model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])epoch = checkpoint['epoch']loss = checkpoint['loss']
Des problèmes :
- Pour l'utiliser pour la formation, appelez model.train().
- Pour l'utiliser pour l'inférence, appelez model.eval().
Utiliser les Artefacts W&B pour le contrôle des versions du modèle
Essayez le cahier Colab, pour enregistrer le modèle en tant qu'artefacts de Weights & Biases et consommez les mêmes.
Les artefacts W&B peuvent être utilisés pour stocker et conserver la trace des ensembles de données, des modèles et des résultats d'évaluation dans les pipelines d'apprentissage automatique. Pensez à un artefact comme à un dossier de données versionné. Plus d'informations sur les artefacts ici.
Enregistrer en tant qu'artefacts
# Importimport wandb# Save your model.torch.save(model.state_dict(), 'save/to/path/model.pth')# Save as artifact for version control.run = wandb.init(project='your-project-name')artifact = wandb.Artifact('model', type='model')artifact.add_file('save/to/path/model.pth')run.log_artifact(artifact)run.join()
Charger en tant qu'artefacts
Cela permettra de télécharger le modèle sauvegardé. Vous pouvez ensuite charger le modèle en utilisant torch.load.
import wandbrun = wandb.init()artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')artifact_dir = artifact.download()run.join()

Weights & Biases
Weights & Biases vous aide à garder la trace de vos expériences d'apprentissage automatique. Utilisez notre outil pour enregistrer les hyperparamètres et les métriques de sortie de vos exécutions, puis visualisez et comparez les résultats et partagez rapidement vos conclusions avec vos collègues.
Lectures recommandées pour ceux qui s'intéressent à PyTorch
How To Use GPU with PyTorch
A short tutorial on using GPUs for your deep learning models with PyTorch, from checking availability to visualizing usable.
PyTorch Dropout for regularization - tutorial
Learn how to regularize your PyTorch model with Dropout, complete with a code tutorial and interactive visualizations
Image Classification Using PyTorch Lightning and Weights & Biases
This article provides a practical introduction on how to use PyTorch Lightning to improve the readability and reproducibility of your PyTorch code.
Transfer Learning Using PyTorch Lightning
In this article, we have a brief introduction to transfer learning using PyTorch Lightning, building on the image classification example from a previous article.
Add a comment
Cons Contre
1 reply
Iterate on AI agents and models faster. Try Weights & Biases today.