
How to save and load models in PyTorch

This article is a machine learning tutorial on how to save and load your models in PyTorch using Weights & Biases for version control.
Created on November 19|Last edited on November 5
Model training is expensive and time-consuming for practical use cases. Saving the trained model is usually the last step in most ML workflows, followed by reusing it for inference.
There are several ways of saving and loading a trained model in PyTorch. In this tutorial, we will look at some of the ways to save and load a trained model in PyTorch. For detailed instructions, check out the official PyTorch documentation. 

Save and load PyTorch models with state_dict (recommended)

A state_dict is simply a Python dictionary that maps each layer to its parameter tensors. The learnable parameters of a model (convolutional layers, linear layers, etc.) and registered buffers (e.g., BatchNorm's running_mean) have entries in the state_dict.
Using state_dict is the recommended approach for several reasons, including:
  • First, state_dict stores only the essential parameters of the model (such as the weights and biases), which keeps file sizes smaller and allows for easy manipulation.
  • Second, since state_dict is a Python dictionary, you can save not only model parameters but also optimizer states and other metadata, making it easier to resume training or fine-tune models.
  • Additionally, state_dict simplifies the process of reloading models across different environments, as it is compatible with Python's built-in pickle module for straightforward saving and loading.
This guide will take you through the steps to save and load models in PyTorch using state_dict and explain best practices for effective model management.
You can learn more about state_dict here.
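To make the layer-to-tensor mapping concrete, here is a short sketch that inspects the state_dict of a tiny toy model (the model itself is hypothetical, purely for illustration):

```python
import torch
import torch.nn as nn

# A tiny example model (hypothetical, for illustration only)
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

# Each entry maps a parameter (or buffer) name to its tensor
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# prints entries like: 0.weight (8, 4), 0.bias (8,), 2.weight (2, 8), 2.bias (2,)
```

Note that parameter-free layers like ReLU contribute no entries, which is why only the two Linear layers appear.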
Here's the code to save and load your models in PyTorch using state_dict:

Save your model in PyTorch using state_dict

torch.save(model.state_dict(), 'save/to/path/model.pth')

Load your model in PyTorch using state_dict

model = MyModelDefinition(args)
model.load_state_dict(torch.load('load/from/path/model.pth'))

Pros

  • PyTorch relies on Python’s pickle module, which allows Python dictionaries to be easily pickled, unpickled, updated, and restored. Using state_dict provides flexibility in managing model parameters.
  • Along with model parameters, you can save additional elements like optimizer states and hyperparameters as key-value pairs in the state_dict. These can be easily accessed when reloading the model.

Cons:

  • To load a state_dict, you’ll need the exact model definition. Without it, loading the saved parameters will not work correctly.

Gotchas:

  • Make sure to call model.eval() before running inference to set layers like dropout and batch normalization to evaluation mode.
  • Save the model using .pt or .pth extension for consistency and compatibility.
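Putting these pieces together, here is a self-contained round-trip sketch using a small stand-in model and a temporary file (the `weights_only=True` flag, available in recent PyTorch versions, restricts unpickling to plain tensors for safety):

```python
import os
import tempfile
import torch
import torch.nn as nn

# A small stand-in model (hypothetical, for illustration only)
model = nn.Linear(4, 2)

# Save the state_dict, then reload it into a fresh instance of the same architecture
path = os.path.join(tempfile.mkdtemp(), 'model.pth')
torch.save(model.state_dict(), path)

restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path, weights_only=True))
restored.eval()  # set dropout/batchnorm layers to evaluation mode

with torch.no_grad():  # no gradient tracking needed for inference
    output = restored(torch.randn(1, 4))
print(output.shape)  # torch.Size([1, 2])
```

The key point is that the architecture is rebuilt in code first; the state_dict only fills in the tensors.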

Save and load the entire PyTorch model

When saving and loading models in PyTorch, you have the option to save the entire model rather than just the state_dict. This approach captures the complete model architecture and parameters in one step, making it quick and easy to restore with minimal code.
While convenient, saving the full model is generally not recommended. Since it relies on Python's pickle module, the saved model file is tightly bound to the specific class definitions and directory structure used at the time of saving. This dependency can lead to compatibility issues if the code is refactored or if you want to use the model in a different project.
Let's look at how to save and load your PyTorch models in this manner.

Save the entire PyTorch model

torch.save(model, 'save/to/path/model.pt')

Load the entire PyTorch model

model = torch.load('load/from/path/model.pt')

Pros:

  • Saving the entire model requires the least amount of code, making it a quick option for saving and loading.
  • The API for saving and loading the full model is straightforward and intuitive.

Cons:

  • Since Python’s pickle module is used internally, the saved model is tightly bound to specific class definitions and the directory structure. This means that any refactoring or changes to the file paths can cause loading issues.
  • Using the full model save in a different project is challenging, as the directory structure and class paths must be identical to those used when saving.

Gotchas:

  • Make sure to call model.eval() before making inferences to ensure proper behavior in layers like dropout and batch normalization.
  • Save the model using .pt or .pth extension when saving your model to maintain consistency and readability.
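A self-contained sketch of the full-model round trip, again with a toy stand-in model (note that in recent PyTorch versions, loading pickled full objects requires passing `weights_only=False` explicitly):

```python
import os
import tempfile
import torch
import torch.nn as nn

# Toy stand-in model; full-model saving pickles the class reference itself,
# so the class (here nn.Linear) must be importable at load time
model = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), 'model.pt')
torch.save(model, path)

# weights_only=False is required for full objects in recent PyTorch versions
restored = torch.load(path, weights_only=False)
restored.eval()
print(type(restored).__name__)  # Linear
```

This is exactly where the fragility shows: if the model's class were renamed or moved to another module, the load above would fail.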

Save and load your PyTorch model from a checkpoint

In most machine learning pipelines, saving model checkpoints periodically or based on certain conditions is essential. This practice allows you to resume training from the latest or best checkpoint, ensuring continuity in case of interruptions. Checkpoints are also useful for fine-tuning and evaluating model performance at different stages.
When saving a checkpoint, saving only the model’s state_dict is not sufficient. You should also save the optimizer’s state_dict, the last epoch number, the current loss, and any other relevant information needed to seamlessly resume training.

Save a PyTorch model checkpoint

torch.save({'epoch': EPOCH,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS},
           'save/to/path/model.pth')

Load a PyTorch model checkpoint

model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load('load/from/path/model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

Gotchas:

  • When resuming training, remember to call model.train() to ensure the model is set to training mode.
  • For inference after loading a checkpoint, call model.eval() to set the model to evaluation mode.
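The full checkpoint round trip can be sketched end to end with a toy model and optimizer standing in for your real training objects (the epoch and loss values are arbitrary placeholders):

```python
import os
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim

# Toy model and optimizer (stand-ins for your real training objects)
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Save everything needed to resume: weights, optimizer state, epoch, loss
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')
torch.save({'epoch': 5,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': 0.42},
           path)

# Later: rebuild the objects, then restore their states from the checkpoint
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

model.train()  # back to training mode before resuming
print(checkpoint['epoch'], checkpoint['loss'])  # 5 0.42
```

Restoring the optimizer state matters for optimizers with internal buffers (e.g., SGD momentum or Adam moments); without it, resumed training effectively restarts the optimizer from scratch.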


Use W&B Artifacts for model version control

W&B Artifacts provide a powerful way to save and load PyTorch models with version control. An artifact can be thought of as a versioned directory, allowing you to store and track model versions and other assets in a structured way.
More on artifacts here.

Save a PyTorch model as a W&B Artifact

# Import
import wandb
# Save your model.
torch.save(model.state_dict(), 'save/to/path/model.pth')
# Save as artifact for version control.
run = wandb.init(project='your-project-name')
artifact = wandb.Artifact('model', type='model')
artifact.add_file('save/to/path/model.pth')
run.log_artifact(artifact)
run.finish()

Load a PyTorch model from a W&B Artifact

The following code downloads the saved model artifact. You can then load the model using torch.load.
import wandb
run = wandb.init()

artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')
artifact_dir = artifact.download()

run.finish()
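The final loading step can be sketched with a local temporary directory standing in for the artifact_dir returned by artifact.download() (this assumes the file was logged as model.pth, as in the save step above):

```python
import os
import tempfile
import torch
import torch.nn as nn

# A temp directory stands in for the artifact_dir returned by artifact.download()
artifact_dir = tempfile.mkdtemp()
model = nn.Linear(4, 2)  # toy stand-in for your real model
torch.save(model.state_dict(), os.path.join(artifact_dir, 'model.pth'))

# Rebuild the architecture, then load the weights from the downloaded directory
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(os.path.join(artifact_dir, 'model.pth')))
restored.eval()
```

In a real run, you would replace the temp-directory setup with the artifact.download() call shown above and use your own model class in place of nn.Linear.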

Figure 1: Check out the live artifact dashboard here

With W&B Artifacts, saving and loading PyTorch models becomes seamless, providing easy version control and making it simple to manage model iterations in your project.

Try Weights & Biases

Weights & Biases helps you keep track of your machine learning experiments. Try our tool to log hyperparameters and output metrics from your runs, then visualize and compare results and quickly share findings with your colleagues.
Get started in 5 minutes, or run 2 quick experiments on Replit to see how W&B can help organise your work by following the instructions below:
Instructions:
  1. Click the green "Run" button below (the first time you click Run, Replit will take approx 30-45 seconds to allocate a machine)
  2. Follow the prompts in the terminal window (the bottom right pane below)
  3. You can resize the terminal window (bottom right) for a larger view


