W&B Integration Best Practices
Integrations Guide
This article covers everything a maintainer or anyone raising a pull request needs to know about integrating the wandb client into a Python library or framework.
Installation
Depending on the design of the library, there are multiple ways to install wandb or include it in the dependency tree of the library installation. It is often recommended to include wandb as a dependency in the library’s setup.py or other similar dependency management files such as pyproject.toml. This step ensures that wandb is installed as a dependency whenever a user installs the library using pip.
Here’s an example that includes wandb as a dependency in requirements.txt:
# YOLOv5 requirements
# Logging
wandb
# --- other requirements ---
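If the library uses setup.py for packaging, the dependency can be declared there instead. Here’s a minimal sketch; the package name and other details are placeholders:
# setup.py (sketch; package name and other dependencies are placeholders)
from setuptools import setup, find_packages

setup(
    name="my_library",
    version="0.1.0",
    packages=find_packages(),
    install_requires=[
        "wandb",  # installed automatically alongside the library
        # --- other requirements ---
    ],
)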
An alternate pattern is to check for an existing installation of wandb in the environment when the user accesses a feature such as wandb.init. If an existing installation is not found, report a useful error message that directs the user to install wandb in their project environment.
Here’s an example of this pattern.
def is_wandb_available():
    try:
        import wandb
        return True
    except ModuleNotFoundError:
        return False

# Usage
if not is_wandb_available():
    raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
Initialization
Before you begin integrating wandb features into the library, check if a wandb callback already exists for the framework the library is using. Many ML training frameworks implement a callback mechanism to log configurations, metrics, and artifacts, and wandb already provides integrations for several popular frameworks, repositories, and libraries. If you would like to use or see a feature that is not implemented in an existing callback, consider raising an issue or PR here.
Setup and Authentication
The user is also required to log in to Weights & Biases so that the experiment logs and artifacts can be synced to a centralized wandb server. The recommended way to authenticate in a library integration is to reference the WANDB_API_KEY environment variable. Additional setup steps can be found here.
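For illustration, a library can check for the API key up front and fall back to an interactive login only when it is missing; the helper name below is hypothetical:
import os
import wandb

def ensure_wandb_login():
    # wandb picks up WANDB_API_KEY from the environment automatically
    if os.getenv("WANDB_API_KEY"):
        return
    # otherwise fall back to cached credentials or an interactive prompt
    wandb.login()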
Depending on the design of the library, there are multiple ways to integrate W&B. The first step is to initialize wandb. This starts a new wandb run: a unit of computation logged by wandb, typically an ML experiment.
wandb.init(project='project_name', name='run_name', config={...}, ...)
More details about wandb.init parameters and usage can be found in the docs.
Best practices to set up and initialize wandb runs:
- Check if a wandb run already exists before initializing a new run:
run = wandb.run if wandb.run else wandb.init()
- Reference environment variables like WANDB_ENTITY and WANDB_PROJECT whenever possible.
- Ensure that the wandb process is initialized only once, in the main process. If there are multiple wandb initializations, ensure that they are grouped using the group argument in wandb.init.
- Log existing configurations and experimental parameters by updating wandb.config.
Here’s a simplified excerpt from the transformers integration highlighting some of these best practices when setting up the wandb client.
# https://github.com/huggingface/transformers/blob/df15703b422644c7cbbb9779af9e8f1fd639eefe/src/transformers/integrations.py#L555
def setup(self, args, state, config):
    # 3: ensure wandb is initialized only in the zero (main) process
    if state.is_world_process_zero:
        # 1: ensure that an existing run is not in progress
        if self._wandb.run is None:
            self._wandb.init(
                # 2: use environment variables if possible
                entity=os.getenv("WANDB_ENTITY"),
                project=os.getenv("WANDB_PROJECT", "my_project"),
                **init_args,  # add any other relevant kwargs here
            )
        # 4: add config parameters (run may have been created manually)
        self._wandb.config.update(combined_dict, allow_val_change=True)
In a distributed environment, the wandb client must be initialized in the parent process. This ensures that the logs are always pushed by only a single process. More information on distributed logging can be found here.
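For example, in a distributed setup launched with a torchrun-style launcher (the RANK environment variable is an assumption of that launcher), initialization can be guarded as follows:
import os
import wandb

# only the main process (rank 0) owns the wandb run; RANK is set by launchers such as torchrun
if int(os.environ.get("RANK", 0)) == 0:
    run = wandb.init(project="my_project")  # project name is a placeholder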
Configuration and Environment Variables
Make use of wandb.config to hold all of the hyperparameters associated with a wandb run.
wandb.init(config={"epochs": 4, "batch_size": 32})
# later
wandb.config.update({"backbone_type": "resnet", "channels": 16})
You can treat this object as a Python dictionary that stores the hyperparameters of the experiment being logged. It’s recommended that you set the config as a parameter to wandb.init, i.e. wandb.init(config=my_config_dict). However, depending on the design of the library, there are also other ways to pass the config to wandb.
A wandb.config can be initialized from other Python objects such as argparse Namespace objects, absl flags, and ConfigDicts, and even from YAML files. Here’s an example of a YAML configuration file, which must be named config-defaults.yaml:
# config-defaults.yaml
# sample config defaults file
epochs:
value: 100
batch_size:
value: 32
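As mentioned above, the config can also be seeded from an argparse namespace. Here’s a minimal sketch with illustrative argument names:
import argparse
import wandb

parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=32)
args = parser.parse_args()

# the parsed namespace is stored as the run's config
run = wandb.init(project="my_project", config=args)  # project name is a placeholder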
The guide details this and many other wandb configuration mechanisms in further detail.
Another way to configure options in wandb is by using environment variables. It’s often recommended that you inform the user, in the installation guide or README.md of the library, to set the required environment variables. A comprehensive list of environment variables specific to wandb configuration can be found here.
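For instance, a user script can set these variables before initialization; the values below are placeholders and would normally be set in the shell or a .env file:
import os
import wandb

os.environ["WANDB_ENTITY"] = "my_team"      # placeholder entity
os.environ["WANDB_PROJECT"] = "my_project"  # placeholder project

run = wandb.init()  # picks up the entity and project from the environment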
Logging
Once initialized, you can log metrics, media, configurations, models, datasets, and other relevant information for tracking your experiments using wandb.log. The simplest way to log metrics and media is to call wandb.log and pass it a Python dictionary of the metrics that you wish to track.
wandb.log({'acc': acc, 'loss': loss,})
A detailed guide and walkthrough using wandb.log are documented here.
Commit and Step
By default, information logged to wandb is tracked over time. However, this behavior can be overridden using the step argument in wandb.log. For instance, you can log multiple keys with the same step; in this case, W&B writes all the collected keys and values to the history. It’s important to note that steps are always incremental, i.e. you can only log to the current step or a later one. You can also set commit=False in wandb.log to accumulate metrics; just be sure to eventually call wandb.log with commit=True to persist them. For example, commit can be used when you want to collect metrics over each batch and then log them at the end of an epoch.
Here’s a simple example using step and commit to change the default behavior.
# step example
for i in range(10):
    wandb.log({'loss': 0.2}, step=i)
    wandb.log({'accuracy': 0.2}, step=i)

# commit example
for batch in epoch:
    # log for each batch in an epoch
    wandb.log({'loss': 0.2}, commit=False)
# at the end of an epoch
wandb.log({'accuracy': 0.8})
When logging media, plots, and other non-native data types, you must ensure that the right wandb data type is used. For instance, use the wandb.Image data type when logging images and wandb.Audio to log audio files. This ensures that these data types are visualized on the dashboard. Here’s a quick example of logging an image data type:
wandb.log({"my_image": wandb.Image('./my_image.png'), ...})
It’s highly recommended that you use artifacts and aliases when logging media in your library. You’ll learn more about artifacts and how to use them later in this guide. Here’s a short example of logging a dataset and then referencing an image from the dataset.
wandb.log_artifact(val_data, name="val_data", type="dataset", aliases=["my-cool-alias"])
art = wandb.use_artifact("val_data:my-cool-alias")
img_1 = art.get(path_to_img1)
wandb.log({"image": img_1})
Define metrics
Many ML frameworks allow control over the frequency of logging and evaluation over mini-batches, steps, and epochs. By default, calling wandb.log for metrics tracks them over time. This records the history of the metric being logged in the run and makes the step the default x-axis in all wandb plots on the dashboard. You can change this behavior by defining custom metrics using wandb.define_metric. Here’s a simple example of how this can be done.
# Define the custom x-axis metric
wandb.define_metric("custom_step")
# Define which metrics to plot against that x-axis
wandb.define_metric("validation/loss", step_metric="custom_step")

for i in range(10):
    log_dict = {
        "train/loss": 1 / (i + 1),
        "custom_step": i ** 2,
        "validation/loss": 1 / (i + 1),
    }
    wandb.log(log_dict)
If you wish to have control over the frequency of logging, you can also use the step argument in wandb.log to change this behavior. For instance, it might make sense to accumulate all the metrics from a given operation in a dictionary and log them together with a single call to wandb.log. You can either log multiple steps per epoch or log one step in each epoch by averaging across the metrics of mini-batches. More information related to incremental and stepwise logging can be found in the documentation here.
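A minimal sketch of the epoch-averaging approach, assuming a hypothetical train_step function and data loader:
# accumulate per-batch metrics and log the epoch average with a single wandb.log call
epoch_metrics = {"train/loss": 0.0, "train/acc": 0.0}
for batch in train_loader:          # hypothetical data loader
    loss, acc = train_step(batch)   # hypothetical training step returning scalar metrics
    epoch_metrics["train/loss"] += loss
    epoch_metrics["train/acc"] += acc

num_batches = len(train_loader)
wandb.log({k: v / num_batches for k, v in epoch_metrics.items()})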
Artifacts
A wandb.Artifact can be used to log datasets and models, automatically version them, and visualize and query datasets in the dashboard. This allows you to trace the flow of data through your pipeline. To ensure reproducibility of hyperparameters, model configurations, and datasets, make sure that you log the various artifacts in the library’s pipeline.
Think of an artifact as a versioned folder of data. You can store entire datasets and models directly in artifacts, or use artifact references to point to data in other systems.
Depending on the model size and the frequency of checkpoints created for a particular task, you can devise a strategy for logging model artifacts and datasets.
For example, model artifacts can be named “run-<run_id>-model”, where each checkpoint is its own version for a given run.
Here's how you log the above model artifact.
# class WandbLogger(BaseLogger_if_needed) continued ..
def log_model(self, model_path):
    wandb.log_artifact(model_path, name='run_' + wandb.run.id + '_model', type='model')
When you log artifacts, you get a directed graph of the operations that used the artifacts as input or logged them as outputs. Here's an example artifacts graph.
You can then access these artifacts in your scripts via the Artifacts API.
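For example, a checkpoint logged with the naming scheme above could later be retrieved like this (a sketch; the artifact name and alias are assumptions based on the example):
import wandb

run = wandb.init(project="my_project")  # placeholder project
# "run_abc123_model" is a placeholder name following the scheme above
artifact = run.use_artifact("run_abc123_model:latest", type="model")
model_dir = artifact.download()  # downloads the checkpoint files locally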
You can log datasets as wandb.Table objects. When you log a dataset as a wandb.Table object, you can visualize the entire dataset directly on the wandb dashboard. You can compare versions, run queries, and perform statistical analysis directly in your browser.
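A minimal sketch of logging an evaluation table; the column names and the predictions/labels variables are illustrative:
# build a small table of model predictions and log it to the dashboard
table = wandb.Table(columns=["id", "prediction", "label"])
for i, (pred, label) in enumerate(zip(predictions, labels)):  # hypothetical model outputs
    table.add_data(i, pred, label)
wandb.log({"eval_table": table})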
Further guides and references to artifact usage can be found in the Artifacts Reference.
De-duplication with Artifacts
To avoid duplication of media such as images and audio, you can also log these data types as a wandb.Artifact and reference them in a wandb.Table using wandb.use_artifact.
# create an artifact to store your dataset
artifact = wandb.Artifact(name="my_dataset", type="dataset")

# list the images to be logged in the artifact
img_files = [path for path in dataset.image_paths]

# create an empty table to log the media
table = wandb.Table(columns=["id", "image"])

# add the images to the artifact
for img_file in img_files:
    artifact.add_file(img_file, name='data/images/' + img_file.name)

# add the images to the table, referencing the same paths
for si, img in enumerate(img_files):
    table.add_data(si, wandb.Image(img))

# add the table to the artifact under the name used to retrieve it later
artifact.add(table, "my_dataset")
wandb.log_artifact(artifact)

## --------- reuse the artifact later ----- ##
artifact = wandb.use_artifact(artifact)  # declare the artifact as an input to this run
artifact.wait()  # wait for the artifact upload to finish before reading from it
table = artifact.get("my_dataset")
wandb.log({"my_table": table})
Distributed Logging
Some libraries consist of various modules, metrics, configurations, outputs, and model checkpoints that aren’t accessible from a single module. In such cases, the best option would be to create a wrapper around wandb's vanilla logger. Here's a simplified skeleton of one of many possible wrapper classes.
class WandbLogger(BaseLogger_if_needed):
    def __init__(self, **kwargs):
        '''
        Initialize wandb. kwargs is expected to contain project, name, config and/or
        other wandb.init() arguments.
        '''
        self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
        self.log_dict = {}  # dict to accumulate metrics and/or media
        ...

    def log(self, log_dict, flush=False):
        '''
        Accumulate data in self.log_dict. If flush is set to True,
        the accumulated metrics are logged directly to the wandb dashboard.
        '''
        for key, value in log_dict.items():
            self.log_dict[key] = value
        if flush:  # for cases where you don't want to accumulate data
            self.flush()
        ...

    def flush(self):
        '''
        Log the accumulated data to the wandb dashboard. Practically, when called
        once per epoch, this logs the data for that epoch/step.
        '''
        wandb.log(self.log_dict)
        self.log_dict = {}

    def finish(self):
        '''
        Finish this W&B run.
        '''
        self.run.finish()
The above skeleton can be modified to suit the needs of a particular library.
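A sketch of how such a wrapper might be used inside a training loop; the data loader and loss computation are hypothetical:
logger = WandbLogger(project="my_project", name="exp-1", config={"epochs": 10})

for epoch in range(10):
    for batch in train_loader:                # hypothetical data loader
        loss = compute_loss(batch)            # hypothetical loss computation
        logger.log({"train/loss": loss})      # accumulate without pushing
    logger.log({"epoch": epoch}, flush=True)  # push everything once per epoch

logger.finish()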
Sweeps
Use wandb.sweep to automate hyperparameter tuning and explore the space of possible models.
The sweeps feature can be used both via the CLI and the Python API. This quickstart guide should get you started on how to configure sweeps in your project.
Again, it’s crucial to remember that sweeps and agents need to be started and run from the parent process. It’s also important to pass sweep configurations before running sweeps. Check the documentation here for more information on using sweeps. Sweeps can also be configured using YAML files. Here’s an example of a YAML sweep configuration:
#https://github.com/wandb/examples/blob/master/examples/keras/keras-cnn-fashion/sweep-grid-hyperband.yaml
program: train.py
method: grid
metric:
goal: minimize
name: val_loss
parameters:
dropout:
values: [0.15, 0.2, 0.25, 0.3, 0.4]
hidden_layer_size:
values: [96, 128, 148]
momentum:
values: [0.8, 0.9, 0.95]
epochs:
value: 27
early_terminate:
type: hyperband
s: 2
eta: 3
max_iter: 27
import yaml  # used to load the sweep config defined above
sweep_config = yaml.safe_load(open("sweep-grid-hyperband.yaml"))  # wandb.sweep expects a config dict, not a YAML path
sweep_id = wandb.sweep(sweep_config)
wandb.agent(sweep_id, function=_objective, count=n_trials)
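For reference, here is a minimal sketch of the _objective function used above; the model-building and training helpers are hypothetical, and the hyperparameter names follow the YAML config:
def _objective():
    # each sweep trial starts its own run; the agent injects the sampled hyperparameters into run.config
    with wandb.init() as run:
        config = run.config
        model = build_model(hidden_layer_size=config.hidden_layer_size,
                            dropout=config.dropout)  # hypothetical helper
        for epoch in range(config.epochs):
            val_loss = train_one_epoch(model, momentum=config.momentum)  # hypothetical helper
            wandb.log({"val_loss": val_loss})  # matches the sweep's metric.name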
Checklist and FAQ for authors:
- Log the experiment configurations: The overview section of a run in the wandb dashboard contains the configurations passed as a config dict to wandb.init. It is recommended to log all the arguments of the particular pipeline responsible for generating a run, making it easier to reproduce the results. More on config.
- Set job_type for distinct parts of pipelines: wandb.init supports a job_type parameter that uses string names to group different parts of a pipeline. For example, you can group train, test, and evaluation runs by setting the job_type to "train", "test", or "eval" (see the sketch at the end of this checklist). Here's a stylegan2 dashboard where the runs are grouped by job_type. If you set grouping in your script, we will group the runs by default in the table in the UI. You can toggle this on and off by clicking the Group button at the top of the table. You can read more about grouping here. Click the edit button in the upper right corner of a graph and select the Advanced tab to change the line and shading. You can select the mean, minimum, or maximum value for the line in each group.
- Log the metrics: Log the relevant metrics with wandb.log and ensure that you have control over step and commit as required by your application. Give relevant names to the metrics being logged. For instance, use / as a delimiter to split the same metric across train, validation, and test, like so: train/loss, val/loss, and test/loss.
- Model Checkpointing: Ensure that model checkpoints are logged as wandb.Artifact objects where applicable. This allows complete repeatability of experiments. It also allows you to share the model weights along with the results of an experiment and reuse models from the right experiment at the inference stage.
- Log media: Along with metrics, you can also log the outputs of evaluation results to showcase the training progress. Depending on the task, you can log images, video, audio, tables, interactive bounding box images, segmentation maps, point clouds, and more. Here's an example dashboard. The media panel can be used to visualize the improvement in accuracy across steps. More on how to log interactive media.
Example of a bounding box media panel for the object detection task.
- Log Tables for model evaluation: In the training loop, include logging of tables of evaluation predictions and results. This allows you to not only track metrics but also view the model's performance over time (epochs/batches). Furthermore, if the library includes media, logging media to tables will allow you to visualize the model's predictions over different epochs.
- Executable colabs: It is recommended to have quick colabs in the README of the repository that introduce users to the W&B integration. If the training pipeline uses very heavy models, you can showcase only the inference part using a pre-trained model. Check out the wandb example colabs repo for examples covering many tasks, frameworks, and libraries.
- Testing: Consider adding the wandb SDK test suite as part of the library's unit tests and continuous integration pipeline. You can find further details related to SDK testing here.
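As referenced in the job_type item above, here is a minimal sketch of grouping pipeline stages; the project and group names are placeholders:
# a training run and an evaluation run that belong to the same experiment group
train_run = wandb.init(project="my_project", group="experiment-1", job_type="train")
# ... training code ...
train_run.finish()

eval_run = wandb.init(project="my_project", group="experiment-1", job_type="eval")
# ... evaluation code ...
eval_run.finish()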