How to Checkpoint Models on Weights & Biases
A primer on how to use W&B's Model Registry to version and checkpoint your models
Created on September 21 | Last edited on October 4
Introduction
Much like in standard software development, versioning in machine learning is a necessity. Broadly speaking, versioning facilitates enhanced governance, faster debugging, easy rollbacks, and better collaboration. In a machine learning context specifically, versioning lets us do all of the above, in addition to letting us reproduce models more easily, rescue crashed training runs, and compare model versions and training runs to each other.
While versioning your data is vital in ML, in this piece, we’ll be focusing on versioning our models (a.k.a. model checkpointing) in Weights & Biases. We’ll train a simple MNIST model for our example, but the steps and best practices apply to any other model as well.
We’ll make heavy use of W&B Artifacts to store versioned checkpoints of our models and the W&B Model Registry to promote, document, and share high-performing checkpoints. You can see our docs for deeper detail about both features, but that’s likely not necessary for this introduction.
Okay. Let’s jump right in.
Note: You'll need to log your runs to a W&B Team to use the Model Registry. You can learn how to create a team here. Once you've created a Team, you can log to it by setting the entity parameter to your team name in wandb.init().
Install Weights & Biases
First things first: you’ll need an active W&B account to get started. If you don’t have one, you can sign up here. Install W&B in a Python environment:
pip install wandb
Then log in in a notebook by running:
import wandb
wandb.login()
Or if you're running in the command line interface:
wandb login
Train a model
At this point, it’s worth level-setting on when you want to actually checkpoint your models. While this is context- and use-case-dependent, you’ll usually want to checkpoint your model as you train it. If, for example, you checkpoint the model at each epoch or every n steps, you’ll get two main benefits:
Easily find your best performing version: Models don’t necessarily keep improving epoch over epoch. You may find that your model had the highest validation accuracy (or any other relevant metric) on epoch 5 versus epoch 10 due to the model overfitting to the training set. You can either keep iterating from there or, if you’re satisfied with your performance, use that model.
Resume training runs: Checkpointing your models during training lets you resume your training run if your training run crashes midway through or exhibits behavior you’re not particularly keen on. Think of this as reverting or rolling back to a moment in time during the training process.
Here, we’ll train a model for 10 epochs. First we start a W&B run and begin training our model, logging metrics as normal. We’ll also save a model checkpoint every epoch using W&B Artifacts. When we save it, we’ll also note the epoch number and accuracy for that checkpoint as additional aliases. This will help us find the most relevant checkpoint later.
import wandb

wandb.init(entity="morg", project="my-llm-fine-tuning")

n_epochs = 10
for epoch in range(n_epochs):
    for batch in dataloader:
        # ... your model training code ...

        # log your train metrics to W&B every batch
        wandb.log({"train/loss": train_loss})

    # log your validation metrics to W&B every epoch
    wandb.log({"val/val_loss": val_loss, "val/val_accuracy": val_accuracy})

    # save a model checkpoint at the end of each epoch as a W&B Artifact
    model_artifact = wandb.Artifact(name="llama2-fine-tune", type="model")

    # add your model weights file to the artifact
    model_artifact.add_file("model.pt")

    # log the Artifact to W&B, noting the epoch and accuracy as aliases
    wandb.log_artifact(model_artifact, aliases=[f"epoch-{epoch+1}", f"val_accuracy-{val_accuracy}"])
At this point, our model’s been trained and we’ve logged it as a W&B Artifact at each epoch. Let’s take a look at which epoch had the best performance during our training run:
Here, we can see that our model actually had a higher validation accuracy at epoch 9 than epoch 10. We want to make sure we note that explicitly in our Model Registry. We’ll see how to do that next.
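Since each checkpoint carries its accuracy as an alias, you can also find the best version programmatically rather than eyeballing a chart. Below is a minimal sketch; the alias format `val_accuracy-<value>` follows how we logged checkpoints above, and the project and artifact paths in the usage comment are from this example, so adjust them to your own:

```python
def best_version_by_accuracy(versions):
    """Given (version, aliases) pairs, return the version whose
    'val_accuracy-<value>' alias is highest."""
    def accuracy_of(item):
        for alias in item[1]:
            if alias.startswith("val_accuracy-"):
                return float(alias.split("-", 1)[1])
        return float("-inf")  # versions without the alias never win
    return max(versions, key=accuracy_of)[0]

# Usage against the W&B API (paths are from this example):
#
#   import wandb
#   api = wandb.Api()
#   versions = [(v.version, v.aliases) for v in
#               api.artifact_versions("model", "morg/my-llm-fine-tuning/llama2-fine-tune")]
#   print(best_version_by_accuracy(versions))
```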
Organizing your models in W&B
When we’ve identified a model checkpoint that we’d like to elevate to the Model Registry (for example, our best performing model to date) we can link that model to the W&B Model Registry.
First, navigate to your project's Artifacts tab. From there, find the artifact version for epoch 9 and click the “Usage” tab to get the Artifacts path to that checkpoint version.

Then copy this path, use it to fetch the correct model version via the wandb API, and link it to a particular Model Collection.
import wandb

# Fetch the Model Version via the API
my_artifact = wandb.Api().artifact('morg/llama2-fine-tune/llama2-fine-tune:v49', type='model')

# Link the Model Version to the Model Collection
my_artifact.link("morg/llama2-fine-tune/chat-support-llama2")
Alternatively, you can link an Artifact version to the Model Registry via the UI. First, select the Artifact version you want in your project’s Artifacts tab, then click the “Link to Registry” button in the top right corner.

If you haven’t already created a Registered Model to add this artifact version to, you can create a new one when you click “Link to Registry” like so:

Now your best model checkpoint is registered in the Model Registry. To see how to use the W&B Model Registry as a central system of record for your best models, standardized and organized across projects and teams, check out our docs.
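Once a checkpoint is linked, downstream consumers can fetch it by its registry path instead of hunting through project artifacts. Here is a small sketch, assuming the `<entity>/model-registry/<registered-model>:<alias>` path format the Model Registry uses; the entity and model names below are from this example:

```python
def registry_path(entity, registered_model, alias="latest"):
    """Build the artifact path used to fetch a model from the Model Registry."""
    return f"{entity}/model-registry/{registered_model}:{alias}"

# Usage (downloads the linked checkpoint, assuming the collection exists):
#
#   import wandb
#   path = registry_path("morg", "chat-support-llama2", "latest")
#   artifact = wandb.Api().artifact(path, type="model")
#   artifact.download("models/registered")
```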
How to resume training from a checkpoint
Say you want to download a specific versioned checkpoint of your model from Weights & Biases. You can do so by running the following:
import wandb

wandb.init(entity="morg", project="my-llm-fine-tuning")

# Fetch the Model Version via the API
model_artifact = wandb.Api().artifact('morg/llama2-fine-tune/llama2-fine-tune:v49', type='model')

# Download the checkpoint to a directory of your choice
model_artifact.download("models/llama2-ckpts")
Here, we’re using the model weights file we saved in the “llama2-fine-tune” Artifact, specifying we want version 49 and downloading that data to a directory called “models/llama2-ckpts”.
If there is a different model checkpoint you’d like to use, you can specify a different Artifact alias, e.g. by changing "v49" to “latest” to get the most recent version, or to another version string such as “v50”.
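With the checkpoint downloaded, resuming training is a matter of loading the saved state and continuing from the next epoch. A hedged sketch follows: the dict keys (`model_state`, `optimizer_state`, `epoch`) are our own convention, not part of the W&B API, and assume you saved the checkpoint with `torch.save` in that shape:

```python
def load_for_resume(checkpoint):
    """Given a loaded checkpoint dict, return the model state, the
    optimizer state (if any), and the epoch to resume from."""
    return (checkpoint["model_state"],
            checkpoint.get("optimizer_state"),
            checkpoint["epoch"] + 1)

# Usage with PyTorch (assumes the checkpoint was saved as
# torch.save({"model_state": ..., "optimizer_state": ..., "epoch": ...}, "model.pt")):
#
#   import torch
#   ckpt = torch.load("models/llama2-ckpts/model.pt")
#   model_state, opt_state, start_epoch = load_for_resume(ckpt)
#   model.load_state_dict(model_state)
#   optimizer.load_state_dict(opt_state)
#   for epoch in range(start_epoch, n_epochs):
#       ...  # continue training as before
```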
Conclusion
Checkpointing your models is widely considered a best practice in machine learning, and W&B's Model Registry can help you do just that. It's also a great tool to organize big, complex projects, automate workflows, and a whole lot more. If you'd like to learn more about our Model Registry, our docs are a great place to start. And, if you'd like to get started with W&B, signing up is free and takes just a couple of minutes.
Happy modeling and thanks for reading!