
Getting started with Weights & Biases in Databricks

A brief guide on how to get the most out of W&B and Databricks
In this post, we'll walk you through how to get started with W&B alongside Databricks. First up? A quick introduction to W&B:

Weights & Biases (W&B)

Weights & Biases is an MLOps platform built to facilitate collaboration and reproducibility across the machine learning development lifecycle. Machine learning projects can quickly become a mess without best practices in place to aid developers and scientists as they iterate on models and move them to production.
Weights & Biases is lightweight enough to work with whatever framework or platform teams are currently using, but enables teams to quickly start logging their important results to a central system of record. On top of this system of record, W&B has built visualization, automation, and documentation capabilities for better debugging, model tuning, and project management.

Setting up Weights & Biases in Databricks

The first thing you'll need to do before you start coding in a Databricks notebook and logging your work with W&B is to make sure wandb is installed on your cluster. You'll also need to authenticate to the W&B server when you log runs, so it's best to use the Databricks secrets API to store and access your W&B API token:

# Install wandb on your cluster (via the cluster's library settings
# or a notebook-scoped install with %pip install wandb).

# Store your W&B API token as a Databricks secret (run this once from the Databricks CLI):
#   databricks secrets put --scope my-scope --key WANDB_API_TOKEN
#   When prompted, copy your token from your account

import os
import wandb

# Retrieve your token from within a Databricks notebook and log in
wandb_api_token = dbutils.secrets.get(scope="my-scope", key="WANDB_API_TOKEN")
os.environ["WANDB_API_KEY"] = wandb_api_token  # wandb reads WANDB_API_KEY for authentication
wandb.login()

Logging to the W&B system of record

Experiment Tracking

W&B has a few core primitives that make up the SDK's experiment tracking system. You can log pretty much anything with W&B: scalar metrics, images, video, custom plots, and more.
To get an idea of the variety of data types you can log, check out the report below, which has code snippets for the different media types that may pertain to your use case.
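For instance, here is a minimal sketch logging an image and a histogram alongside a scalar (the project name and data below are placeholders, not from the original report):

import numpy as np
import wandb

run = wandb.init(project="media-logging-demo")  # hypothetical project name

# Log a scalar, an image, and a histogram in a single call
wandb.log({
    "loss": 0.42,
    "example_image": wandb.Image(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)),
    "gradient_histogram": wandb.Histogram(np.random.randn(1000)),
})
wandb.finish()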

The canonical sections of code that require logging are the training loop and model evaluation, but you can log from any piece of code in your workflow, such as data pre-processing, augmentation, or generation. All you have to do is call wandb.init() and log diagnostic charts, metrics, and mixed media with wandb.log(). An executed piece of code contextualized by wandb.init() is called a run.
You can also embed rich media and plots into W&B Tables, which provide a persistent, interactive evaluation store for your models. Below is an example training loop and evaluation function which log metrics and predictions at the end of each epoch:
### Generic Training Loop
def train_one_epoch(model, criterion, optimizer, scheduler,
                    train_dataloader_iter, steps_per_epoch, epoch, batch_size,
                    device):
    wandb.define_metric('train/mini_batch_step')
    wandb.define_metric('train/running_loss', step_metric='train/mini_batch_step')
    wandb.define_metric('train/running_corrects', step_metric='train/mini_batch_step')
    model.train()  # Set model to training mode

    # statistics
    running_loss = 0.0
    running_corrects = 0

    # Iterate over the data for one epoch.
    for step in range(steps_per_epoch):
        pd_batch = next(train_dataloader_iter)
        inputs, labels = pd_batch['features'].to(device), pd_batch['label_index'].to(device)
        # Track history in training
        with torch.set_grad_enabled(True):
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            # backward + optimize
            loss.backward()
            optimizer.step()

        # statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        wandb.log({'train/running_loss': running_loss,
                   'train/running_corrects': running_corrects,
                   'train/mini_batch_step': step})
    scheduler.step()

    epoch_loss = running_loss / (steps_per_epoch * batch_size)
    epoch_acc = running_corrects.double() / (steps_per_epoch * batch_size)

    return epoch_loss, epoch_acc

### Evaluation Script to Assess Model Errors
def evaluate(model, criterion,
             val_dataloader_iter,
             validation_steps,
             batch_size,
             device,
             metric_agg_fn=None):
    model.eval()  # Set model to evaluate mode

    # statistics
    running_loss = 0.0
    running_corrects = 0
    columns = ["image", "guess", "truth"]
    for l, idx in label_to_idx.items():
        columns.append("score_" + l)
    predictions_table = wandb.Table(columns=columns)
    val_inputs_labels = [next(val_dataloader_iter) for i in range(validation_steps)]
    val_data = [v['features'] for v in val_inputs_labels]
    val_labels = [v['label_index'] for v in val_inputs_labels]
    val_data, val_labels = torch.vstack(val_data), torch.vstack(val_labels)
    val_labels = torch.reshape(val_labels, (-1,))

    # Do not track history in evaluation to save memory
    with torch.set_grad_enabled(False):
        outputs = model(val_data.to(device))
        _, preds = torch.max(outputs, 1)
        val_data_permuted = torch.permute(val_data, (0, 2, 3, 1))

        # Log each validation example, its prediction, and per-class scores into a W&B Table
        for img, pred, gt, scores in zip(val_data_permuted, preds, val_labels, outputs):
            row = [wandb.Image(img.numpy()),
                   idx_to_label[pred.item()],
                   idx_to_label[gt.item()]]
            for s in scores.numpy().tolist():
                row.append(np.round(s, 4))
            predictions_table.add_data(*row)
        loss = criterion(outputs, val_labels)

        # statistics
        running_loss += loss.item()
        running_corrects += torch.sum(preds == val_labels.data)
    # Average the losses across observations for each minibatch.
    epoch_loss = running_loss / validation_steps
    epoch_acc = running_corrects.double() / (validation_steps * batch_size)
    return epoch_loss, epoch_acc, predictions_table
Here's a simple example using W&B to log some scalar metrics:
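The sketch below uses placeholder values; in a real workflow these would come from your training loop:

import wandb

run = wandb.init(project="scalar-logging-demo")  # hypothetical project name

for epoch in range(10):
    # Stand-in metrics; replace with your own loss and accuracy computations
    train_loss = 1.0 / (epoch + 1)
    val_acc = 0.5 + 0.04 * epoch
    wandb.log({"train/loss": train_loss, "val/accuracy": val_acc, "epoch": epoch})

wandb.finish()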




Delta Lake and W&B Artifacts

Artifacts are inputs and outputs of each part of your machine learning pipeline—namely datasets and models.
Datasets in Databricks are typically stored in Delta Lake, which allows you to essentially convert cloud object stores into structured tables with automatic versioning and ACID transaction capabilities.
Training datasets change over time as new data is collected, removed, or re-labeled. Models change as new architectures are implemented and as they are continually retrained. With these changes, all downstream tasks utilizing the changed datasets and models will be affected, and understanding this dependency chain is critical for debugging effectively. W&B can log this dependency graph easily with a few lines of code.

Tracking a Delta Table by reference

W&B can track a reference to the Delta table we are using and attach its transaction log as an interactive table, where we can see who made what changes to the table over time.
run = wandb.init(project=project_name, entity=entity, job_type='delta_ingest')
# Track a reference to the delta table itself
df_artifact = wandb.Artifact(name='flowers_delta_table', type="delta_table")
df_artifact.add_reference('file:///dbfs/tmp/delta/flower_photos')

# Log the transaction log as an interactive table
delta_history = deltaTable.history().toPandas()
wandb_delta_history = wandb.Table(dataframe=delta_history)
df_artifact.add(wandb_delta_history, name='delta_history')

run.log_artifact(df_artifact)
wandb.finish()

flowers_delta_table (Artifact overview)
Type: delta_table
Created At: July 19th, 2022
Schema:
 |-- size: struct (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- height: integer (nullable = true)
 |-- label: string (nullable = true)
 |-- content: binary (nullable = true)

Versions:
Version | Aliases    | Created         | TTL Remaining | # of Consuming Runs | Size
v2      | latest, v2 | Tue Jul 19 2022 | Inactive      | 1                   | 7.6MB
v1      | v1         | Tue Jul 19 2022 | Inactive      | 2                   | 6.1kB
v0      | v0         | Tue Jul 19 2022 | Inactive      | 1                   | 2.6kB

Declaring a dependency on Delta Table versions

For downstream runs that utilize a specific version of the Delta table, all we have to do is log the version we are using in wandb.config.
Here we are converting the Delta table to a different format with Petastorm before loading it into a dataloader for training. We want to keep track of which version of the Delta table we use in this process so that we can tie all the downstream models we build to this particular version. We can also track a reference to the Petastorm dataset:
run = wandb.init(project=project_name, entity=entity, job_type='convert_to_petastorm')

# Declare the run's dependency on the Delta table artifact and record the exact version used
run.use_artifact('flowers_delta_table:latest')
latest_version = deltaTable.history().toPandas()['version'].sort_values(ascending=False).iloc[0]
wandb.config.delta_table_version = latest_version

# Set a cache directory on DBFS FUSE for intermediate data.
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, "file:///dbfs/tmp/petastorm/cache")

converter_train = make_spark_converter(df_train)
converter_val = make_spark_converter(df_val)

df_artifact = wandb.Artifact(name='flowers_data_splits', type="petastorm_dataset", metadata={"delta_table_version": latest_version})
df_artifact.add_reference('file:///dbfs/tmp/petastorm/cache')
run.log_artifact(df_artifact)
wandb.finish()

flowers_data_splits (Artifact overview)
Type: petastorm_dataset
Created At: July 15th, 2022

Versions:
Version | Aliases    | Created         | TTL Remaining | # of Consuming Runs | Size | m.delta_table_version
v2      | latest, v2 | Tue Jul 19 2022 | Inactive      | 1                   | 0B   | 10
v1      | v1         | Tue Jul 19 2022 | Inactive      | 17                  | 0B   | -
v0      | v0         | Fri Jul 15 2022 | Inactive      | 2                   | 0B   | -
With W&B Artifacts, we obtain a complete picture of a machine learning pipeline so we can better understand how and where issues arise and isolate the problem area.

Model registry

W&B's model registry provides a centralized place to house the promising model versions that are critical to your projects and ML tasks. It lets users "bookmark" the models that need attention, making the models they want to move to production visible and easy to identify. Users can also write rich markdown or add companion artifacts like sample predictions or HTML to create rich model cards.
In Databricks, we might split our workflow into two notebooks: training and evaluation/inference. We can use dbutils to parametrize these notebooks with arguments referring to the W&B Model Registry, as sketched below. In the training notebook, we specify which model collection to register our best trained model to, and in the inference notebook, we pick up the current best version from the registry for testing.
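As a rough sketch, the parametrization might look like this (the widget name and registry path below are assumptions, not taken from the original notebooks):

# In the training notebook: declare a widget so the notebook can be parametrized
dbutils.widgets.text("model_collection_name", "<entity>/model-registry/Flowers Prediction")
model_collection_name = dbutils.widgets.get("model_collection_name")

# In the inference notebook: read the same parameter and append an alias to fetch a version
dbutils.widgets.text("model_collection_name", "<entity>/model-registry/Flowers Prediction")
registered_model = dbutils.widgets.get("model_collection_name") + ":latest"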



In this notebook, we log the model after each epoch, and the best one will be linked to the collection "Flowers Prediction":
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, criterion, optimizer, exp_lr_scheduler,
                                            train_dataloader_iter, steps_per_epoch, epoch, batch_size,
                                            device)
    val_loss, val_acc, predictions_table = evaluate(model, criterion, val_dataloader_iter, validation_steps, batch_size, device)
    wandb.log({'train/train_loss': train_loss,
               'train/train_acc': train_acc,
               'validation/val_loss': val_loss,
               'validation/val_acc': val_acc,
               'validation_predictions': predictions_table})
    is_best = val_loss < best_loss
    if is_best:
        best_loss = val_loss
    art = wandb.Artifact(f"flowers-prediction-{architecture_name}-{wandb.run.id}", "model")
    torch.save(model, "model.pt", pickle_module=cloudpickle)
    art.add_file("model.pt")
    wandb.log_artifact(art, aliases=["best", "latest"] if is_best else None)
    if is_best:
        best_model = art

wandb.run.link_artifact(best_model, model_collection_name, ["latest"])
wandb.finish()

In the inference notebook, we retrieve the latest version of the registered model via "Flowers Prediction:latest".

In the inference UDF, we pull down this model and run predictions through it:
run = wandb.init(project=project_name, entity=entity, group=group, job_type='inference')
model_art = run.use_artifact(model_collection_name, type='model')
artifact_dir = model_art.download()
model = torch.load(artifact_dir + '/model.pt')
model.eval()
We can view and compare all the versions of models that have been registered in the past, add aliases to indicate stage and maturity, and finally add rich documentation around what the model expects as input and produces as output:

Flowers Prediction (Model card)
Type: model
Created At: July 15th, 2022
Description:

“I was left alone there in the company of the orchids, roses and violets, which, like people waiting beside you who do not know you, preserved a silence which their individuality as living things made all the more striking, and warmed themselves in the heat of a glowing coal fire...”

― Marcel Proust

This model is a fine-tuned version of mobilenet_v2 from torchvision, trained on the flowers dataset to predict 5 species of flowers: sunflowers, dandelion, roses, tulips, and daisies.

Expected Inputs:

For inference or training, preprocessing consists of resizing (shorter side to 256), center-cropping to 224x224, and normalizing with ImageNet statistics:

transform = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

Model Raw Output:

Batched tensor of 5 probabilities for the different flower species. Below is the index number for each species, which you can use to post-process the scores:

{
     'daisy': 0, 
     'dandelion': 1, 
     'roses': 2, 
     'sunflowers': 3,
     'tulips': 4
}

References

  1. Mobilenet_v2 paper: https://arxiv.org/abs/1801.04381
  2. Flowers Dataset
  3. Databricks reference solution
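As a quick illustration of the post-processing described above, here is a hedged sketch; the outputs tensor below is a random stand-in for the model's raw batched output:

import torch
import torch.nn.functional as F

idx_to_label = {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}

outputs = torch.randn(4, 5)                 # stand-in for a real batch of raw model outputs
probs = F.softmax(outputs, dim=1)           # per-class probabilities (if the raw outputs are logits)
top_idx = torch.argmax(probs, dim=1)        # most likely species index per image
predicted_species = [idx_to_label[i.item()] for i in top_idx]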

Interactive tables

W&B Tables enable a granular analysis of predictions and results through tabular data manipulation. Oftentimes, understanding a model's behavior during or after training requires more than seeing a clean loss curve go down and to the right. We need to understand where specifically the model fails, what examples are giving it trouble, where we might need to collect more training data/re-label, or maybe even uncover more nuanced errors like numerical instability.
Tables can be used as a model evaluation store, which stores consolidated results on golden validation datasets across different trained models in your project. They can also be used as model leaderboards, where each row is a model class or architecture with embedded explainability or custom performance charts alongside them. These are both best practices which you can start incorporating with a few lines of code.
Simply pass rich media into a wandb.Table and you will have a persistent, interactive visualization for comparing models on individual samples.
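As a minimal sketch (the images and labels here are random placeholders):

import numpy as np
import wandb

run = wandb.init(project="table-demo")  # hypothetical project name

# Build a table of per-example predictions with an embedded image per row
table = wandb.Table(columns=["image", "guess", "truth"])
for _ in range(4):
    fake_img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
    table.add_data(wandb.Image(fake_img), "roses", "tulips")

wandb.log({"validation_predictions": table})
wandb.finish()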





Aside: Spotting data skew using W&B Tables and Spark UDFs

When performing inference in Databricks, we often want to distribute the work across the Spark cluster. We can do this using PySpark user-defined functions (UDFs), specifically the Scalar Iterator pandas UDF variant, which lets us initialize some state inside the UDF before performing the work within the function.
We can build a UDF which 1) pulls the model from the W&B Model Registry, 2) initializes a W&B run, and 3) tracks some metrics and predictions. See below:
def model_inference_udf(project_name,
                        entity,
                        group,
                        model_collection_name):
    """
    1. Pulls down a model from the W&B Model Registry,
    2. Creates a pandas UDF from the model
    3. Runs inference on a test dataset
    4. Logs predictions into a W&B Table
    """
    def predict(batch_iter):
        run = wandb.init(project=project_name, entity=entity, group=group, job_type='inference')
        model_art = run.use_artifact(model_collection_name, type='model')
        artifact_dir = model_art.download()
        model = torch.load(artifact_dir + '/model.pt')
        model.eval()
        run_name = run.name
        columns = ["image", "guess", "truth", "run_name"]
        for l, idx in label_to_idx.items():
            columns.append("score_" + l)
        for content_series, label_series in batch_iter:
            dataset = ImageNetDataset(list(content_series), list(label_series))
            loader = DataLoader(dataset, batch_size=64)
            predictions_table = []
            with torch.no_grad():
                for image_batch, label_batch in loader:
                    outputs = model(image_batch.to(device))
                    _, preds = torch.max(outputs, 1)
                    image_batch = torch.permute(image_batch, (0, 2, 3, 1))
                    rows = []
                    for img, pred, gt, scores in zip(image_batch, preds, label_batch, outputs):
                        row = [wandb.Image(img.numpy()),
                               idx_to_label[pred.item()],
                               gt, run_name]
                        for s in scores.numpy().tolist():
                            row.append(np.round(s, 4))
                        rows.append(row)
                    table_batch = pd.DataFrame(rows, columns=columns)
                    predictions_table.append(table_batch)
            predictions_table = pd.concat(predictions_table, axis=0)
            # Log the predictions as an interactive W&B Table (consider filtering to just
            # the errored predictions to avoid logging huge datasets)
            wandb.log({'inference_results': wandb.Table(dataframe=predictions_table)})
            yield predictions_table[['guess', 'truth']]
        wandb.finish()

    return_type = "guess: string, truth: string"
    return pandas_udf(f=predict, returnType=return_type, functionType=PandasUDFType.SCALAR_ITER)
When we run inference and write predictions to the Delta table, this invokes the UDF across the Spark cluster, logging compute utilization and example predictions on each worker node.
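Here is a hedged sketch of that invocation; df_test and the output path are assumptions, while the content and label columns follow the flowers schema shown earlier:

from pyspark.sql.functions import col

# Build the pandas UDF; each worker pulls the registered model when the UDF initializes
predict_udf = model_inference_udf(project_name, entity, group, model_collection_name)

# Apply it across the cluster and persist predictions back to Delta (hypothetical path)
predictions_df = df_test.select(predict_udf(col("content"), col("label")).alias("prediction"))
(predictions_df
 .select(col("prediction.guess"), col("prediction.truth"))
 .write.format("delta")
 .mode("overwrite")
 .save("/tmp/delta/flower_predictions"))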
This is an easy way to spot data skew, one of the biggest issues plaguing Spark users. Looking at the row counts processed by the different runs, we can see that upbeat-snowball-37 is processing barely any data, so we are under-utilizing that worker node. We also see in the CPU and memory utilization that this run barely did any work. Slacker.

[Embedded panels: per-run row counts and CPU/memory utilization for the 4 inference runs]


Hyperparameter sweeps

One of the more tedious aspects of training deep learning models is tuning hyperparameters. When we log runs in W&B, we can make W&B aware of their hyperparameters.
A central sweep controller can then delegate new hyperparameter combinations based on a set of distributions we specify across the hyperparameter space, via a YAML file or a Python dictionary.
If we do a Bayesian search, W&B can even seed the search with previous runs we've already logged. Below is an example with a simple training function that exposes several hyperparameters to W&B via wandb.config. wandb.sweep then initializes a hyperparameter search using a dictionary of distributions for the hyperparameter space. We also recently added nested sweep configs, which allow for more flexibility in the types of searches you can do.
def train(config=None):
    # Initialize a new wandb run
    with wandb.init(config=config):
        # If called by wandb.agent, as below,
        # this config will be set by the Sweep Controller
        config = wandb.config

        loader = build_dataset(config.batch_size)
        network = build_network(config.fc_layer_size, config.dropout)
        optimizer = build_optimizer(network, config.optimizer, config.learning_rate)

        for epoch in range(config.epochs):
            avg_loss = train_epoch(network, loader, optimizer)
            wandb.log({"loss": avg_loss, "epoch": epoch})

sweep_config = {
    'method': 'random',
    # 'method': 'grid',
    # 'method': 'bayes',
}

parameters_dict = {
    'optimizer': {
        'values': ['adam', 'sgd']
    },
    'fc_layer_size': {
        'values': [128, 256, 512, 1024]
    },
    'dropout': {
        'values': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    },
    # Static hyperparameter, notice the singular key
    'epochs': {
        'value': 10
    },
    'learning_rate': {
        # Flat distribution between 0 and 0.25
        'distribution': 'uniform',
        'min': 0,
        'max': 0.25
    },
    'batch_size': {
        # Quantized log-uniform distribution between 32 and 256
        'distribution': 'q_log_uniform',
        'q': 1,
        'min': math.log(32),
        'max': math.log(256),
    }
}

### Initializes the central sweep server
sweep_config['parameters'] = parameters_dict
sweep_id = wandb.sweep(sweep_config, project="sweeps-demo-pytorch")

### Run this on multiple machines/cores to distribute the hyperparameter search
wandb.agent(sweep_id, train, count=10)


W&B automatically groups the runs associated with a sweep and creates charts that let us do meta-analysis on which hyperparameter combinations are working well.

[Embedded panels: charts for the 10 runs in sweep 0ewut602]


Reports

W&B Reports help contextualize and document the system of record built through logging diagnostics and results from different pieces of your pipeline. Reports are interactive and dynamic, reflecting filtered run sets logged in W&B. In fact, you're reading one now. You can add all sorts of assets to a report, including plots, tables, images, code, and nested reports.
Whether you are writing technical summaries, regulatory documentation, or just want a real-time dashboard reflecting the progress of your team, reports can be a best practice documentation layer for your data science and machine learning projects. Check out the below gallery for some interesting ideas:


Going beyond the core W&B primitives

wandb.log, wandb.Artifact, wandb.Table, and wandb.sweep can take you far in building your machine learning system of record, and they form the core of the best practices we see top machine learning research teams employ in their everyday workflows. Beyond these primitives, our team continues to build out integrations with higher-level frameworks and tools, where simply adding a single W&B callback or function argument gets everything logged automatically under the hood.
Check out our integrations page and double-check the docs of your favorite machine learning repo, as there might already be a W&B integration in place! Let us know if you'd like to see W&B integrated into a package or tool we aren't yet logging!
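For example, with the Keras integration, a single callback handles the logging. Here is a minimal sketch with a stand-in model, random data, and a hypothetical project name:

import numpy as np
import tensorflow as tf
import wandb
from wandb.keras import WandbCallback

wandb.init(project="keras-integration-demo")  # hypothetical project name

# Tiny stand-in model and data just to show the callback wiring
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(64, 4), np.random.rand(64, 1)

# One callback argument logs losses, metrics, and system stats to W&B
model.fit(x, y, epochs=3, callbacks=[WandbCallback(save_model=False)])
wandb.finish()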
