Getting started with Weights & Biases in Databricks
Weights & Biases (W&B)
Setting up Weights & Biases in Databricks

import os
import wandb

# Store your W&B API token as a Databricks Secret (run once from the Databricks CLI):
#   databricks secrets put --scope my-scope --key WANDB_API_TOKEN
#   When prompted, paste the token from your W&B account settings.

# Retrieve the token from within a Databricks notebook and log in
wandb_api_token = dbutils.secrets.get(scope="my-scope", key="WANDB_API_TOKEN")
os.environ["WANDB_API_KEY"] = wandb_api_token
wandb.login()
Logging to the W&B system of record
Experiment Tracking
### Generic Training Loop
def train_one_epoch(model, criterion, optimizer, scheduler,
                    train_dataloader_iter, steps_per_epoch, epoch, batch_size,
                    device):
    wandb.define_metric('train/mini_batch_step')
    wandb.define_metric('train/running_loss', step_metric='train/mini_batch_step')
    wandb.define_metric('train/running_corrects', step_metric='train/mini_batch_step')
    model.train()  # Set model to training mode

    # statistics
    running_loss = 0.0
    running_corrects = 0

    # Iterate over the data for one epoch.
    for step in range(steps_per_epoch):
        pd_batch = next(train_dataloader_iter)
        inputs, labels = pd_batch['features'].to(device), pd_batch['label_index'].to(device)

        # Track history in training
        with torch.set_grad_enabled(True):
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            # backward + optimize
            loss.backward()
            optimizer.step()

        # statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        wandb.log({'train/running_loss': running_loss,
                   'train/running_corrects': running_corrects,
                   'train/mini_batch_step': step})

    scheduler.step()

    epoch_loss = running_loss / (steps_per_epoch * batch_size)
    epoch_acc = running_corrects.double() / (steps_per_epoch * batch_size)
    return epoch_loss, epoch_acc


### Evaluation Script to Assess Model Errors
def evaluate(model, criterion, val_dataloader_iter, validation_steps, batch_size,
             device, metric_agg_fn=None):
    model.eval()  # Set model to evaluate mode

    # statistics
    running_loss = 0.0
    running_corrects = 0

    # Build a W&B Table with one score column per class
    columns = ["image", "guess", "truth"]
    for l, idx in label_to_idx.items():
        columns.append("score_" + l)
    predictions_table = wandb.Table(columns=columns)

    val_inputs_labels = [next(val_dataloader_iter) for i in range(validation_steps)]
    val_data = [v['features'] for v in val_inputs_labels]
    val_labels = [v['label_index'] for v in val_inputs_labels]
    val_data, val_labels = torch.vstack(val_data), torch.vstack(val_labels)
    val_labels = torch.reshape(val_labels, (-1,))

    # Do not track history in evaluation to save memory
    with torch.set_grad_enabled(False):
        outputs = model(val_data.to(device))
        _, preds = torch.max(outputs, 1)

        val_data_permuted = torch.permute(val_data, (0, 2, 3, 1))
        for img, pred, gt, scores in zip(val_data_permuted, preds, val_labels, outputs):
            row = [wandb.Image(img.numpy()),
                   idx_to_label[pred.item()],
                   idx_to_label[gt.item()]]
            for s in scores.numpy().tolist():
                row.append(np.round(s, 4))
            predictions_table.add_data(*row)

        loss = criterion(outputs, val_labels)

        # statistics
        running_loss += loss.item()
        running_corrects += torch.sum(preds == val_labels.data)

    # Average the losses across observations for each minibatch.
    epoch_loss = running_loss / validation_steps
    epoch_acc = running_corrects.double() / (validation_steps * batch_size)
    return epoch_loss, epoch_acc, predictions_table
Delta Lake and W&B Artifacts
Tracking a Delta Table by reference
run = wandb.init(project=project_name, entity=entity, job_type='delta_ingest')

# Track a reference to the Delta table itself
df_artifact = wandb.Artifact(name='flowers_delta_table', type="delta_table")
df_artifact.add_reference('file:///dbfs/tmp/delta/flower_photos')

# Log the transaction log as an interactive table
delta_history = deltaTable.history().toPandas()
wandb_delta_history = wandb.Table(dataframe=delta_history)
df_artifact.add(wandb_delta_history, name='delta_history')

run.log_artifact(df_artifact)
wandb.finish()
 |-- size: struct (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- height: integer (nullable = true)
 |-- label: string (nullable = true)
 |-- content: binary (nullable = true)
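For context, here is a minimal sketch of how a Delta table with a schema like the one above might be produced from the raw flower images. The source path, the label extraction, and the omission of the size struct are assumptions for illustration only:

from pyspark.sql.functions import col, regexp_extract

# Read raw JPEGs as binary records (columns: path, modificationTime, length, content)
raw_images = (spark.read.format("binaryFile")
              .option("recursiveFileLookup", "true")
              .option("pathGlobFilter", "*.jpg")
              .load("/databricks-datasets/flower_photos"))

# Derive the label from the parent folder name and keep the raw image bytes
flowers_df = raw_images.select(
    col("content"),
    regexp_extract(col("path"), r"flower_photos/(\w+)/", 1).alias("label"))

# Write out as a Delta table at the path referenced by the artifact above
flowers_df.write.format("delta").mode("overwrite").save("dbfs:/tmp/delta/flower_photos")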
Declaring dependency on Delta Table versions:
run = wandb.init(project=project_name, entity=entity, job_type='convert_to_petastorm')

# Declare the dependency on the tracked Delta table and record its version
run.use_artifact('flowers_delta_table:latest')
latest_version = deltaTable.history().toPandas()['version'].sort_values(ascending=False).iloc[0]
wandb.config.delta_table_version = latest_version

# Set a cache directory on DBFS FUSE for intermediate data.
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, "file:///dbfs/tmp/petastorm/cache")
converter_train = make_spark_converter(df_train)
converter_val = make_spark_converter(df_val)

df_artifact = wandb.Artifact(name='flowers_data_splits', type="petastorm_dataset",
                             metadata={"delta_table_version": latest_version})
df_artifact.add_reference('file:///dbfs/tmp/petastorm/cache')
run.log_artifact(df_artifact)
wandb.finish()
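The converters created above can then be materialized into PyTorch dataloaders. Below is a sketch of how the train_dataloader_iter, steps_per_epoch, and validation_steps used by the training loop might be derived; the batch size and the absence of a transform_spec are assumptions:

batch_size = 64  # assumption for illustration

# SparkDatasetConverter.make_torch_dataloader returns a context manager
# that yields a PyTorch-style dataloader backed by the Petastorm cache.
with converter_train.make_torch_dataloader(batch_size=batch_size) as train_dataloader, \
     converter_val.make_torch_dataloader(batch_size=batch_size) as val_dataloader:

    train_dataloader_iter = iter(train_dataloader)
    steps_per_epoch = len(converter_train) // batch_size

    val_dataloader_iter = iter(val_dataloader)
    validation_steps = max(1, len(converter_val) // batch_size)

    # ... call train_one_epoch / evaluate inside this context ...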
Model registry

best_loss = float('inf')  # track the best validation loss seen so far

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, criterion, optimizer, exp_lr_scheduler,
                                            train_dataloader_iter, steps_per_epoch, epoch,
                                            batch_size, device)
    val_loss, val_acc, predictions_table = evaluate(model, criterion, val_dataloader_iter,
                                                    validation_steps, batch_size, device)
    wandb.log({'train/train_loss': train_loss,
               'train/train_acc': train_acc,
               'validation/val_loss': val_loss,
               'validation/val_acc': val_acc,
               'validation_predictions': predictions_table})

    is_best = val_loss < best_loss
    if is_best:
        best_loss = val_loss

    # Log the epoch's checkpoint as a W&B Artifact
    art = wandb.Artifact(f"flowers-prediction-{architecture_name}-{wandb.run.id}", "model")
    torch.save(model, "model.pt", pickle_module=cloudpickle)
    art.add_file("model.pt")
    wandb.log_artifact(art, aliases=["best", "latest"] if is_best else None)
    if is_best:
        best_model = art

# Link the best checkpoint into the Model Registry collection
wandb.run.link_artifact(best_model, model_collection_name, ["latest"])
wandb.finish()
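For reference, model_collection_name is the target path of a Registered Model in the W&B Model Registry. A sketch of how it might be defined; the registered model name flowers-classifier is a placeholder, not taken from the source:

# Target path in the W&B Model Registry: <entity>/model-registry/<registered model name>
model_collection_name = f"{entity}/model-registry/flowers-classifier"  # name is a placeholder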

run = wandb.init(project=project_name, entity=entity, group=group, job_type='inference')
model_art = run.use_artifact(model_collection_name, type='model')
artifact_dir = model_art.download()
model = torch.load(artifact_dir + '/model.pt')
model.eval()
“I was left alone there in the company of the orchids, roses and violets, which, like people waiting beside you who do not know you, preserved a silence which their individuality as living things made all the more striking, and warmed themselves in the heat of a glowing coal fire...”
― Marcel Proust
This model is a fine-tuned version of mobilenet_v2 from torchvision, trained on the flowers dataset to predict 5 species of flowers: sunflowers, dandelion, roses, tulips, and daisies.
Expected Inputs:
For inference or training, preprocessing consists of resizing the shorter side to 256 pixels, center-cropping to 224x224, and normalizing with the ImageNet mean and standard deviation:
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
Model Raw Output:
A batched tensor of 5 class scores, one per flower species. Below is the index number for each species, which you can use to post-process the scores; a short scoring sketch follows the mapping:
{
'daisy': 0,
'dandelion': 1,
'roses': 2,
'sunflowers': 3,
'tulips': 4
}
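As an end-to-end illustration, here is a minimal sketch of scoring a single image with the preprocessing and index mapping above. The image path is a placeholder, and applying a softmax assumes the raw output is unnormalized scores:

from PIL import Image
import torch

# Mirrors the species-to-index mapping above
idx_to_label = {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}

img = Image.open("some_flower.jpg").convert("RGB")   # placeholder path
batch = transform(img).unsqueeze(0)                  # apply the preprocessing above, add batch dim

with torch.no_grad():
    scores = model(batch)                            # raw model output, shape (1, 5)
    probs = torch.softmax(scores, dim=1)             # normalize scores into probabilities

pred_idx = int(probs.argmax(dim=1))
print(idx_to_label[pred_idx], float(probs[0, pred_idx]))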
References
- Mobilenet_v2 paper: https://arxiv.org/abs/1801.04381
- Flowers Dataset
- Databricks reference solution
Interactive tables
Aside: Spotting data skew using W&B Tables and Spark UDFs
def model_inference_udf(project_name, entity, group, model_collection_name):
    """
    1. Pulls down a model from the W&B Model Registry,
    2. Creates a pandas UDF from the model,
    3. Runs inference on a test dataset,
    4. Filters predictions to just the errors and logs them to a W&B Table.
    """
    def predict(batch_iter):
        run = wandb.init(project=project_name, entity=entity, group=group, job_type='inference')
        model_art = run.use_artifact(model_collection_name, type='model')
        artifact_dir = model_art.download()
        model = torch.load(artifact_dir + '/model.pt')
        model.eval()
        run_name = run.name

        columns = ["image", "guess", "truth", "run_name"]
        for l, idx in label_to_idx.items():
            columns.append("score_" + l)

        for content_series, label_series in batch_iter:
            dataset = ImageNetDataset(list(content_series), list(label_series))
            loader = DataLoader(dataset, batch_size=64)

            predictions_table = []
            with torch.no_grad():
                for image_batch, label_batch in loader:
                    outputs = model(image_batch.to(device))
                    _, preds = torch.max(outputs, 1)
                    image_batch = torch.permute(image_batch, (0, 2, 3, 1))

                    rows = []
                    for img, pred, gt, scores in zip(image_batch, preds, label_batch, outputs):
                        row = [wandb.Image(img.numpy()),
                               idx_to_label[pred.item()],
                               gt, run_name]
                        for s in scores.numpy().tolist():
                            row.append(np.round(s, 4))
                        rows.append(row)

                    table_batch = pd.DataFrame(rows, columns=columns)
                    predictions_table.append(table_batch)

            predictions_table = pd.concat(predictions_table, axis=0)

            # Filter to just the errored predictions to avoid logging huge datasets
            errors_table = predictions_table[predictions_table['guess'] != predictions_table['truth']]
            wandb.log({'inference_results': wandb.Table(dataframe=errors_table)})

            yield predictions_table[['guess', 'truth']]

        wandb.finish()

    return_type = "guess: string, truth: string"
    return pandas_udf(f=predict, returnType=return_type, functionType=PandasUDFType.SCALAR_ITER)
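A sketch of how the resulting pandas UDF might be applied to a Spark DataFrame of test images; the df_test DataFrame and its content and label column names are assumptions:

from pyspark.sql.functions import col

predict_udf = model_inference_udf(project_name, entity, group, model_collection_name)

# Run distributed inference; each Spark task logs its error table to the shared W&B group
preds_df = (df_test
            .select(predict_udf(col("content"), col("label")).alias("prediction"))
            .select(col("prediction.guess"), col("prediction.truth")))

display(preds_df)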
Hyperparameter sweeps
def train(config=None):
    # Initialize a new wandb run
    with wandb.init(config=config):
        # If called by wandb.agent, as below,
        # this config will be set by the Sweep Controller
        config = wandb.config

        loader = build_dataset(config.batch_size)
        network = build_network(config.fc_layer_size, config.dropout)
        optimizer = build_optimizer(network, config.optimizer, config.learning_rate)

        for epoch in range(config.epochs):
            avg_loss = train_epoch(network, loader, optimizer)
            wandb.log({"loss": avg_loss, "epoch": epoch})


sweep_config = {
    'method': 'random',
    # 'method': 'grid',
    # 'method': 'bayes',
}

parameters_dict = {
    'optimizer': {
        'values': ['adam', 'sgd']
    },
    'fc_layer_size': {
        'values': [128, 256, 512, 1024]
    },
    'dropout': {
        'values': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    },
    # Static hyperparameter, notice the singular key
    'epochs': {
        'value': 10
    },
    'learning_rate': {
        # Flat distribution between 0 and 0.25
        'distribution': 'uniform',
        'min': 0,
        'max': 0.25
    },
    'batch_size': {
        # Quantized log-uniform distribution between 32 and 256
        'distribution': 'q_log_uniform',
        'q': 1,
        'min': math.log(32),
        'max': math.log(256),
    }
}
sweep_config['parameters'] = parameters_dict

### Initialize the central sweep server (controller)
sweep_id = wandb.sweep(sweep_config, project="sweeps-demo-pytorch")

### Run this on multiple machines/cores to distribute the hyperparameter search
wandb.agent(sweep_id, train, count=10)
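Because the sweep is coordinated by the central W&B sweep controller, additional agents running on other clusters or notebooks can join the same search by ID. A sketch, where the sweep ID string is a placeholder you would copy from the controller output:

import wandb

# Copy the ID printed by wandb.sweep(), in the form "entity/project/sweep_id"
sweep_id = "my-entity/sweeps-demo-pytorch/abc123"   # placeholder

# Each agent pulls hyperparameter configs from the controller and runs `train`
wandb.agent(sweep_id, function=train, count=10)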
