How to save all your trained model weights locally after every epoch
In this report, I'll show you how to save model weights locally after every epoch during model training. We'll also look at a more scalable way of storing model weights.
Introduction
As machine learning engineers, we run hundreds of experiments to build robust, generalizable models! And since, in general, each experiment trains a model for more than one epoch, that adds up to a few hundred epochs.
But the model weights after the last epoch aren't always the best. Often, the best weights are the ones with the lowest validation loss or the highest validation metric. This means that, to get the best results, we need to be able to track model weights per epoch. This idea is commonly referred to as "checkpoint saving".
It's not new, and most libraries support it. For example, timm (or PyTorch Image Models) has its own checkpoint saver, PyTorch Lightning has a "ModelCheckpoint" callback, and fastai has a "SaveModelCallback".
Here's what goes on behind each of these callbacks: we keep track of the validation loss or a validation metric (like AUC score, accuracy, etc.), and if the current epoch's validation loss is lower (or metric value is higher) than the best seen so far, we save the model weights. Most libraries are flexible and allow users to save as many checkpoints as they like. Usually, the default value is around 5, which means that at any time we keep the top-5 models seen during training. We could then simply ensemble these 5 models and make predictions on the test set!
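For instance, this is roughly how PyTorch Lightning's ModelCheckpoint callback is typically configured to keep the best 5 checkpoints. Treat it as a sketch and check the docs of your Lightning version for the exact arguments:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# keep the 5 checkpoints with the lowest validation loss
checkpoint_callback = ModelCheckpoint(
    dirpath='model_weights/',           # where checkpoints are written
    monitor='val_loss',                 # metric logged by the LightningModule
    mode='min',                         # lower is better for a loss
    save_top_k=5,                       # keep only the best 5 checkpoints
    filename='{epoch}-{val_loss:.4f}',  # encode epoch and score in the file name
)
trainer = Trainer(callbacks=[checkpoint_callback])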
As part of this report, we're going to build our own checkpoint saver from scratch in PyTorch.
Checkpoint Saver in PyTorch
Before we start writing some code, here are the main ideas:
- We need to give users the option to choose how many models they want to keep.
- We need to create a valid name for the model for ease of access later on.
- Give users the ability to provide a directory where they want to save the model weights.
- Either save model weights based on highest validation metric scores or lowest validation loss.
Let's start with a simple CheckpointSaver that does the above.
import numpy as np
import os
import logging

import torch


class CheckpointSaver:
    def __init__(self, dirpath, decreasing=True, top_n=5):
        """
        dirpath: Directory path where to store all model weights
        decreasing: If decreasing is `True`, then lower metric is better
        top_n: Total number of models to track based on validation metric value
        """
        if not os.path.exists(dirpath): os.makedirs(dirpath)
        self.dirpath = dirpath
        self.top_n = top_n
        self.decreasing = decreasing
        self.top_model_paths = []
        self.best_metric_val = np.inf if decreasing else -np.inf

    def __call__(self, model, epoch, metric_val):
        # checkpoint name combines the model class name and the epoch number
        model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')
        save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
        if save:
            logging.info(f"Current metric value {metric_val} better than best {self.best_metric_val}, saving model at {model_path}")
            self.best_metric_val = metric_val
            torch.save(model.state_dict(), model_path)
            self.top_model_paths.append({'path': model_path, 'score': metric_val})
            self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
            if len(self.top_model_paths) > self.top_n:
                self.cleanup()

    def cleanup(self):
        # drop checkpoints that fall outside the top_n best scores
        to_remove = self.top_model_paths[self.top_n:]
        logging.info(f"Removing extra models.. {to_remove}")
        for o in to_remove:
            os.remove(o['path'])
        self.top_model_paths = self.top_model_paths[:self.top_n]
How does it work?
Here's how the above CheckpointSaver addresses each of the four points I mentioned:
- top_n: When we call the checkpoint saver, the condition if len(self.top_model_paths) > self.top_n inside the __call__ method checks whether the total number of saved models exceeds top_n. If it does, the checkpoint saver removes the extra models with lower performance.
- The model_path inside the checkpoint saver combines model.__class__.__name__ with the epoch number, giving each saved model a valid, descriptive file name.
- All models get stored inside dirpath.
- If a decreasing metric value is better, that is, if we are tracking validation loss, we set decreasing=True. In that case the best metric value starts out at np.inf, and at every epoch we check whether the new metric value is lower than the best seen so far; if it is, we save the model, otherwise we don't. The reverse logic applies when decreasing=False: we only save the model if the current score is higher than the best so far.
A working example
To see the above CheckpointSaver being used in practice, refer to a working notebook that I've provided here.
After every epoch, the model weights get saved if the new model performs better than the best one seen so far.
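For reference, here's a minimal sketch of how the CheckpointSaver could sit inside a training loop; model, train_one_epoch, validate, and the dataloaders are hypothetical stand-ins for your own code:

checkpoint_saver = CheckpointSaver(dirpath='./model_weights', decreasing=True, top_n=5)

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, optimizer)   # hypothetical training step
    val_loss = validate(model, valid_loader)          # hypothetical validation step
    checkpoint_saver(model, epoch, val_loss)          # saves weights only if val_loss improves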
Summary of saving models using Checkpoint Saver
I hope that by now you understand how the CheckpointSaver works and how it can be used to save model weights after every epoch whenever the current epoch's model is better than the best one so far.
But can you think of reasons why this might not be the best way to save progress?
- Storage: Each set of model weights can easily be around 100 MB, so we need a few GB just to store checkpoints, and considerably more if we're running a lot of experiments! Bigger models take even more space.
- Model names and scores: Sure, we get a list of models with model names and epochs as part of the file names, but how do we know which experiments these models were a part of? What if we accidentally overwrite the model weights? How do we know what the score was for the Resnet_epoch0 model? One quick fix would be to add metric scores to the model names as well, but that gets messy quickly. Each model would be called something like Resnet_epoch0_0.2313, and IMHO that is definitely not the best way to organize your experiments.
- Storing model names in an Excel sheet with experiment names: We would have to manually keep track of the experiments we run and the model weights stored for each of them.
- Which model is better? Since the model names only carry epoch information, how do we know whether the weights from epoch 4 were better than those from epoch 3?
- How does model performance track over time? This way of saving model weights and progress shows us nothing about how model performance changed from epoch to epoch.
Experiment Tracking with Weights and Biases
If you haven't already read the previous report that's part of this project - How to track all your experiments using Microsoft Excel, I really recommend that you do so to understand how W&B can be used as a powerful tool for experiment tracking.
In that report, I show how to get a beautiful-looking dashboard like the one below:

Figure-1: Sample dashboard on Melanoma project when using W&B for experiment tracking
This means that all our experiments are in one place, and we can compare them to see how a small change in parameter values affects the model's validation performance.
So, now that you know about experiment tracking, let's add model versioning on top of it in the next section, so that we can also track model weights after every epoch as part of each experiment.
A more scalable way of saving your model weights
A more efficient and scalable way, IMHO, is to use Weights & Biases Artifacts.
What are W&B Artifacts?
W&B Artifacts was designed to make it effortless to version your datasets and models, regardless of whether you want to store your files with W&B or whether you already have a bucket you want W&B to track. Once you've tracked your datasets or model files, W&B will automatically log each and every modification, giving you a complete and auditable history of changes to your files.
This lets you focus on the fun and important parts of model training, while W&B handles the otherwise tedious process of tracking all the details.
It's really easy to log wandb artifacts using 4 simple lines of code:
wandb.init()
artifact = wandb.Artifact(<enter_filename>, type='model')
artifact.add_file(<file_path>)
wandb.run.log_artifact(artifact)
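Filled in with hypothetical values (the project name, artifact name, and file path below are just placeholders), that could look like:

import wandb

run = wandb.init(project='melanoma')                          # placeholder project name
artifact = wandb.Artifact('resnet-checkpoint', type='model')  # placeholder artifact name
artifact.add_file('./model_weights/Resnet_epoch0.pt')         # placeholder path to a saved checkpoint
run.log_artifact(artifact)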
So, we already know that W&B can take care of experiment tracking, but how do we add model versioning? Let's update the CheckpointSaver we created before and add the option to log model weights to W&B in addition to storing them locally.
Before I show you how, here are some of the benefits:
- W&B automatically versions each of the model artifacts for you.
- W&B does not overwrite model artifacts: Even if you log two artifacts with the same name, W&B creates a new version rather than overwriting it, unless you explicitly tell it to update a particular version. This is not possible when storing model weights locally, where files with the same name simply get overwritten. You can also pull any specific version back down later, as shown in the sketch after this list.
- All part of a bigger ecosystem: By logging artifacts alongside experiment tracking, all your relevant data lives in one place as part of the wider W&B ecosystem, so there's no longer any need to maintain an Excel sheet of the experiments you run and their file paths.
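Here's a rough sketch of how you could pull a specific artifact version back down in a later run. The project name and the artifact name:version string below are placeholders; use whatever you actually logged:

import wandb

run = wandb.init(project='melanoma')                        # placeholder project name
# 'model-ckpt:v2' is a placeholder artifact name and version alias
artifact = run.use_artifact('model-ckpt:v2', type='model')
weights_dir = artifact.download()                           # downloads the checkpoint files locally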
Now, let's update the Checkpoint Saver class to start saving models to W&B as part of the training process after every epoch.
import wandb


class CheckpointSaver:
    def __init__(self, dirpath, decreasing=True, top_n=5):
        """
        dirpath: Directory path where to store all model weights
        decreasing: If decreasing is `True`, then lower metric is better
        top_n: Total number of models to track based on validation metric value
        """
        if not os.path.exists(dirpath): os.makedirs(dirpath)
        self.dirpath = dirpath
        self.top_n = top_n
        self.decreasing = decreasing
        self.top_model_paths = []
        self.best_metric_val = np.inf if decreasing else -np.inf

    def __call__(self, model, epoch, metric_val):
        model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')
        save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
        if save:
            logging.info(f"Current metric value {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.")
            self.best_metric_val = metric_val
            torch.save(model.state_dict(), model_path)
            # log the saved checkpoint to W&B as a model artifact
            self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
            self.top_model_paths.append({'path': model_path, 'score': metric_val})
            self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
            if len(self.top_model_paths) > self.top_n:
                self.cleanup()

    def log_artifact(self, filename, model_path, metric_val):
        # store the validation score as artifact metadata alongside the weights
        artifact = wandb.Artifact(filename, type='model', metadata={'Validation score': metric_val})
        artifact.add_file(model_path)
        wandb.run.log_artifact(artifact)

    def cleanup(self):
        to_remove = self.top_model_paths[self.top_n:]
        logging.info(f"Removing extra models.. {to_remove}")
        for o in to_remove:
            os.remove(o['path'])
        self.top_model_paths = self.top_model_paths[:self.top_n]
In the above class, we have now added another method called log_artifact. As you can see, it's only three lines of code, thus not much of an overhead to the overall training script.
It's really easy to integrate artifacts into your workflows!
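One thing to keep in mind: wandb.run.log_artifact only works inside an active run, so wandb.init needs to be called before the training loop starts. A minimal sketch, with a placeholder project name:

import wandb

wandb.init(project='melanoma-artifact')   # placeholder project name; call before any artifacts are logged
checkpoint_saver = CheckpointSaver(dirpath='./model_weights', decreasing=True, top_n=5)
# ... training loop calling checkpoint_saver(model, epoch, val_loss) as before ...
wandb.finish()                            # close the run once training is done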
Do you know the best part about using W&B Artifacts? We can also add metadata! When we were saving model files locally, we had no idea what the validation scores were.
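If you want to look that metadata up later, outside the training run, the wandb public API exposes it too. The entity/project/artifact path below is a placeholder based on the example project, so treat this as a sketch:

import wandb

api = wandb.Api()
# placeholder artifact path in the form 'entity/project/artifact-name:version'
artifact = api.artifact('amanarora/melanoma-artifact/model-ckpt-epoch-0.pt:v0', type='model')
print(artifact.metadata)   # e.g. {'Validation score': 0.2313}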
For a complete working notebook that showcases how to use W&B for experiment tracking and model versioning - refer to https://gist.github.com/amaarora/9b867f1868f319b3f2e6adb6bfe2373e.
You can also find an example project that was created when I ran the notebook myself here - https://wandb.ai/amanarora/melanoma-artifact/runs/33vyx0zh?workspace=user-amanarora.
Feel free to explore the example project and look at the model artifact. Each artifact has validation score metadata alongside the model weights! Isn't that cool?
Conclusion
I hope this report has shown how to use Weights & Biases model artifacts on top of experiment tracking, so that you're part of the bigger W&B ecosystem and everything is in one place. I have also provided a working notebook with all the code snippets, as well as an example project for you to play around with when you're first trying W&B Artifacts.
Feel free to use the CheckpointSaver code for your own experiments and give it a try! If you have any questions please reach out to me at aman@wandb.com.
Thanks for reading!