
3D Segmentation with MONAI and PyTorch Supercharged by Weights & Biases

A tutorial on how to use Weights & Biases with MONAI and PyTorch to accelerate your medical research
Created on August 16|Last edited on July 16
In this article, we'll look at how you can integrate Weights & Biases with MONAI and PyTorch.
If you've yet to learn about MONAI, here's a quick introduction. MONAI stands for Medical Open Network for AI. It's a PyTorch-based, open-source framework for deep learning in healthcare imaging. It allows us to create state-of-the-art, end-to-end training workflows for healthcare imaging and provides an optimized and standardized way to create and evaluate deep learning models.
For organizations looking for enterprise support, the MONAI Toolkit for NVIDIA AI Enterprise will fit the bill. NVIDIA AI Enterprise, an end-to-end, secure, cloud-native suite of AI software, accelerates the data science pipeline and streamlines the development and deployment of predictive AI models to automate essential processes and gain rapid insights from data.
Available in the cloud, data center, and at the edge, NVIDIA AI Enterprise includes enterprise support to enable organizations to solve new challenges while increasing operational efficiency. You can access the toolkit here to get up and running.
MONAI offers a lot of utility for healthcare imaging, including domain-specific networks, losses, evaluation metrics, and pre-processing for multi-dimensional medical imaging data. For today's task, we'll use MONAI and W&B to perform 3D segmentation of the spleen.
On the Weights & Biases side, we'll see how we can:
✅ Log experiment configuration.
✅ Log training and evaluation metrics.
✅ Log versioned model checkpoints with W&B artifacts.
✅ Log interactive media for seeing the slices of CT scans of medical images.
✅ Log images with ground truth segmentation masks.
✅ Log the input image, ground truth, and predicted segmentation masks with W&B Tables (useful for error analysis).
✅ Use parallel coordinates and parameter importance charts to analyze hyperparameters.
And much more...
Before reading this report, here are some things to take note of:
Note: The Weights & Biases dashboard for this report can be found here
Note: In this report, I have included only the code relevant to this article. You can view and execute the complete code in the accompanying Colab notebook. I highly recommend running the code and following along with the tutorial for a better understanding.
Note: Training will take a long time on Colab, which provides a slower T4 GPU. I trained all the experiments on an RTX 5000 (which is significantly faster), rented as a spot instance on jarvislabs.ai at a very affordable price ($0.19/hr).



Setup

If you want to follow along with the tutorial, you can access the complete code used for this article in the Colab notebook.

First, let's start by installing MONAI and Weights & Biases.
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
!python -c "import wandb" || pip install -q wandb
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline


That was easy.

Data

Download data

Next, we will download the dataset. For this article, we will be using the Spleen dataset from medicaldecathlon.com.
The spleen is a fist-sized organ in the upper left side of the abdomen, next to the stomach and behind the left ribs. It is an important part of the immune system and also helps keep the blood healthy.
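import os
import tempfile

from monai.apps import download_and_extract

# root_dir is set in the colab notebook; these two lines are an assumed
# equivalent (the pattern MONAI's tutorials use)
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory

# define the link of the dataset
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
# define the hash value to validate the downloaded file
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
# define the path for downloading the .tar file
compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
# define the directory for extracting the contents of the .tar file
data_dir = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_dir):
    # download, extract and validate the file
    download_and_extract(resource, compressed_file, root_dir, md5)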
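The `md5` value passed to `download_and_extract` lets MONAI verify the downloaded archive. The idea behind such a check is simple: hash the file and compare the digest with the expected one. Below is a minimal standard-library sketch of that comparison; the helpers `md5_of_file` and `verify_md5` are hypothetical names, and this is not MONAI's own implementation.

```python
import hashlib


def md5_of_file(path, chunk_size=1 << 20):
    """Return the MD5 hex digest of a file.

    The file is read in 1 MiB chunks, so a large archive is never
    loaded into memory all at once.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() keeps calling f.read until it returns b"" (end of file)
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_md5(path, expected_md5):
    """True if the file's MD5 digest matches `expected_md5` (case-insensitive)."""
    return md5_of_file(path) == expected_md5.strip().lower()
```

For example, `verify_md5(compressed_file, md5)` should return True if the Task09_Spleen.tar download completed correctly.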

Visualize data

This dataset consists of CT scans with spleen annotations. It contains 61 3D volumes (41 for training and 20 for testing), each made up of a stack of 2D slices. We can visualize each 3D volume interactively in W&B as a series of 2D images, and we can also overlay the segmentation masks on each of these slices. Here's what the code to log a single 3D volume looks like:
import numpy as np
import wandb
from monai.data import DataLoader, Dataset
from monai.utils import first

# train_files and val_transforms are defined in the colab notebook


# utility function for generating an interactive image mask from components
def wb_mask(bg_img, mask):
    return wandb.Image(bg_img, masks={
        "ground truth": {
            # W&B expects integer class ids for the mask data
            "mask_data": np.asarray(mask).astype(np.uint8),
            "class_labels": {0: "background", 1: "mask"},
        }
    })


def log_spleen_slices(total_slices=100):

    wandb_mask_logs = []
    wandb_img_logs = []

    check_ds = Dataset(data=train_files, transform=val_transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)  # get the first item of the dataloader

    image, label = (check_data["image"][0][0], check_data["label"][0][0])
    for img_slice_no in range(total_slices):
        img = image[:, :, img_slice_no]
        lbl = label[:, :, img_slice_no]
        # append the image to wandb_img_logs to visualize
        # the slices interactively in the W&B dashboard
        wandb_img_logs.append(wandb.Image(img, caption=f"Slice: {img_slice_no}"))

        # append the image and masks to wandb_mask_logs
        # to see the masks overlaid on the original image
        wandb_mask_logs.append(wb_mask(img, lbl))

    wandb.log({"Image": wandb_img_logs})
    wandb.log({"Segmentation mask": wandb_mask_logs})


# 🐝 init wandb with appropriate project and run name
wandb.init(project="MONAI_Spleen_3D_Segmentation", name="slice_image_exploration")
# 🐝 log images to W&B
log_spleen_slices(total_slices=100)
# 🐝 finish the run
wandb.finish()
And here's how it looks in the dashboard. In the left panel, you can move the slider to step through each slice of the 3D volume. In the right panel (named "Segmentation mask"), you can see the ground truth masks (in red) overlaid on the 2D slices.
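Both panels rely on the same indexing: `log_spleen_slices` treats the last axis of each `(H, W, D)` volume as the slice axis, so `image[:, :, k]` is the k-th 2D slice. That indexing also makes it easy to log only the slices that actually contain spleen instead of a fixed range. The NumPy sketch below illustrates this on a small synthetic volume; `iter_slices` and `slices_with_foreground` are hypothetical helpers (not part of MONAI or W&B), and the label volume is made up for illustration.

```python
import numpy as np


def iter_slices(volume, axis=-1):
    """Yield (index, 2D slice) pairs along `axis` of a 3D array.

    With the default axis=-1 this matches the indexing in log_spleen_slices,
    where volume[:, :, k] is the k-th slice.
    """
    volume = np.asarray(volume)
    if volume.ndim != 3:
        raise ValueError(f"expected a 3D volume, got shape {volume.shape}")
    for k in range(volume.shape[axis]):
        yield k, np.take(volume, k, axis=axis)


def slices_with_foreground(label_volume, axis=-1):
    """Indices of slices whose binary mask contains any foreground (label 1)."""
    return [k for k, s in iter_slices(label_volume, axis) if np.any(s == 1)]


# Synthetic 64 x 64 x 10 label volume with a small cube of "spleen"
# voxels in slices 3..5 (made-up data, for illustration only).
label = np.zeros((64, 64, 10), dtype=np.uint8)
label[20:30, 20:30, 3:6] = 1
print(slices_with_foreground(label))  # [3, 4, 5]
```

Passing the returned indices to the logging loop instead of `range(total_slices)` would skip empty slices; whether that is preferable depends on whether you also want to inspect slices without spleen.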





Configuration

Let's define the configuration as a dictionary so that we can use it throughout our notebook/script. We will also pass this dictionary to wandb.init when initializing the run, which logs it to W&B. wandb.init also lets you set a project name and an experiment/run name, write notes for the experiment, and much more. To learn what you can pass to wandb.init, refer to the documentation here.
from monai.networks.layers import Norm

config = {
    # data
    "cache_rate": 1.0,
    "num_workers": 4,

    # train settings
    "train_batch_size": 2,
    "val_batch_size": 1,
    "learning_rate": 1e-3,
    "max_epochs": 100,
    "val_interval": 10,  # check validation score after n epochs
    "lr_scheduler": "cosine_decay",  # just to keep track

    # Unet model (you can even use a nested dictionary; W&B handles it automatically)
    "model_type": "unet",  # just to keep track
    "model_params": dict(spatial_dims=3,
                         in_channels=1,
                         out_channels=2,
                         channels=(16, 32, 64, 128, 256),
                         strides=(2, 2, 2, 2),
                         num_res_units=2,
                         norm=Norm.BATCH,
                         ),
}
Next, we create the datasets and data loaders and define the model, loss function, optimizer, and scheduler. For this article, we use a 3D UNet model with MONAI's DiceLoss as the loss function, and we evaluate our models with MONAI's DiceMetric. The Dice coefficient is a common metric for evaluating segmentation models; you can refer to this article to learn more about it.
The details of the dataset, data loaders, model, loss functions, etc, can be found in the colab notebook.
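For two binary masks, a prediction P and a ground truth G, the Dice coefficient is 2|P ∩ G| / (|P| + |G|): 1 for perfect overlap and 0 for no overlap. MONAI's DiceMetric computes this per class on the one-hot tensors produced by `AsDiscrete(..., to_onehot=2)`, and DiceLoss optimizes a smoothed, differentiable form of 1 - Dice computed on soft predictions. The NumPy sketch below computes the plain coefficient for a single pair of masks; `dice_coefficient` is a hypothetical helper, and returning 1.0 when both masks are empty is an assumption (libraries handle that case differently).

```python
import numpy as np


def dice_coefficient(pred, target, empty_value=1.0):
    """Dice coefficient between two binary masks of the same shape.

    Dice = 2 * |P ∩ G| / (|P| + |G|), where P is the predicted foreground
    and G the ground-truth foreground. If both masks are empty the ratio
    is undefined and `empty_value` is returned instead.
    """
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    intersection = np.logical_and(pred, target).sum()
    total = pred.sum() + target.sum()
    if total == 0:
        return float(empty_value)
    return float(2.0 * intersection / total)


# Worked example: 4 predicted voxels, 4 ground-truth voxels, 2 overlapping.
pred = np.array([1, 1, 1, 1, 0, 0, 0, 0])
target = np.array([0, 0, 1, 1, 1, 1, 0, 0])
print(dice_coefficient(pred, target))  # 2 * 2 / (4 + 4) = 0.5
```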




Train

Logging training and validation metrics

We are ready to train our model. We'll log the training and validation losses/metrics and the learning rate to W&B with a simple wandb.log() call. wandb.log() is very flexible: it can log anything from scalar values, histograms, plots, images, and tables to 3D objects. You can refer to the documentation to learn everything you can log with wandb.log().
We'll also save the model with the best validation Dice score and log that best score, along with the epoch at which it was reached, to W&B. You'll see that code below.
Note: Code marked with the 🐝 symbol is W&B-specific; everything else is a typical PyTorch training loop. Notice how little code is needed to log losses and metrics to a dashboard you can access from any browser.
import os

import torch
import wandb
from monai.data import decollate_batch
from monai.inferers import sliding_window_inference
from monai.transforms import AsDiscrete, Compose

# model, loss_function, optimizer, scheduler, dice_metric, train_ds,
# train_loader, val_loader, device and root_dir come from the colab notebook

# 🐝 initialize a wandb run
wandb.init(
    project="MONAI_Spleen_3D_Segmentation",
    config=config
)

# 🐝 log gradients of the model to wandb
wandb.watch(model, log_freq=100)

max_epochs = config['max_epochs']
val_interval = config['val_interval']
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"train_loss: {loss.item():.4f}")
        # 🐝 log train_loss for each step to wandb
        wandb.log({"train/loss": loss.item()})
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    # step scheduler after each epoch (cosine decay)
    scheduler.step()
    # 🐝 log train_loss averaged over epoch to wandb
    wandb.log({"train/loss_epoch": epoch_loss})
    # 🐝 log learning rate after each epoch to wandb
    # (get_last_lr() returns the LR set by the last scheduler.step();
    # calling get_lr() directly triggers a warning and can give a different value)
    wandb.log({"learning_rate": scheduler.get_last_lr()[0]})

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (160, 160, 160)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()

            # 🐝 log validation dice score for each validation round
            wandb.log({"val/dice_metric": metric})

            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(
                    root_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )
print(
    f"\ntrain completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}")

# 🐝 log best score and epoch number to wandb
wandb.log({"best_dice_metric": best_metric, "best_metric_epoch": best_metric_epoch})
The charts below show how the loss and metrics change during training and validation, and they make it easy to compare your experiments.
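The `learning_rate` chart should follow the cosine decay named in the config, because the scheduler is stepped once per epoch. The colab's scheduler definition isn't shown in this report. If we assume it is PyTorch's `CosineAnnealingLR` with `T_max=max_epochs` and the default `eta_min=0`, the learning rate after t epochs has the closed form lr_t = eta_min + (lr_0 - eta_min)(1 + cos(π t / T_max)) / 2. The pure-Python sketch below evaluates that formula so you can sanity-check the logged curve; `cosine_lr` is a hypothetical helper.

```python
import math


def cosine_lr(epoch, base_lr=1e-3, total_epochs=100, min_lr=0.0):
    """Learning rate after `epoch` scheduler steps under cosine annealing.

    lr = min_lr + (base_lr - min_lr) * (1 + cos(pi * epoch / total_epochs)) / 2
    The defaults mirror config["learning_rate"] and config["max_epochs"].
    """
    if not 0 <= epoch <= total_epochs:
        raise ValueError("epoch must be in [0, total_epochs]")
    cos_term = (1 + math.cos(math.pi * epoch / total_epochs)) / 2
    return min_lr + (base_lr - min_lr) * cos_term


# Print the schedule at a few points: it starts at base_lr,
# halves at the midpoint, and reaches min_lr at the last epoch.
for e in (0, 25, 50, 75, 100):
    print(e, cosine_lr(e))
```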



Logging the model gradients

You might have noticed the following line in the above code block:
# 🐝 log gradients of the model to wandb
wandb.watch(model, log_freq=100)
wandb.watch logs histograms of the gradients of each layer of the model to W&B every log_freq batches (here, every 100), so you can see how they evolve over the complete training period. Reviewing these plots helps identify problems such as vanishing or exploding gradients.
Here's a gradient chart from one of the experiments:

Run: kind-dawn-65
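The histograms give a visual check; a complementary numeric check is each layer's gradient norm. The NumPy sketch below flags layers whose gradient L2 norm is tiny (possibly vanishing), very large (possibly exploding), or non-finite. The thresholds and the example layer names are illustrative assumptions, not defaults from W&B, MONAI, or PyTorch; the docstring shows one way to build the input dict from a PyTorch model after `loss.backward()`.

```python
import numpy as np


def flag_gradients(grads, vanish_tol=1e-7, explode_tol=1e3):
    """Label each layer's gradient as "ok", "vanishing", "exploding" or "non-finite".

    `grads` maps layer names to gradient arrays. In a PyTorch loop, one way
    to build it right after loss.backward() is:
        {n: p.grad.detach().cpu().numpy()
         for n, p in model.named_parameters() if p.grad is not None}
    The two thresholds are illustrative assumptions, not library defaults.
    """
    report = {}
    for name, g in grads.items():
        norm = float(np.linalg.norm(np.asarray(g, dtype=float).ravel()))
        if not np.isfinite(norm):
            status = "non-finite"
        elif norm < vanish_tol:
            status = "vanishing"
        elif norm > explode_tol:
            status = "exploding"
        else:
            status = "ok"
        report[name] = (norm, status)
    return report


# Hypothetical layer names with synthetic gradients, for illustration only.
example_grads = {
    "down.0.weight": np.full((3, 3), 1e-10),  # norm 3e-10 -> vanishing
    "down.1.weight": np.ones((3, 3)),         # norm 3.0   -> ok
    "up.0.weight": np.full((3, 3), 1e4),      # norm 3e4   -> exploding
}
for name, (norm, status) in flag_gradients(example_grads).items():
    print(f"{name}: L2 norm {norm:.2e} ({status})")
```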


System Metrics

You can also view system metrics such as CPU/GPU utilization, GPU temperature, and network traffic, which W&B records automatically during each run.





Versioning Models

You can also version your models and datasets, just as we version code with Git. W&B Artifacts let you do this; to learn more, refer to the documentation.
Here, we will log the best model with just a couple of lines of code.
# 🐝 Version your model
best_model_path = os.path.join(root_dir, "best_metric_model.pth")
model_artifact = wandb.Artifact(
"unet", type="model",
description="Unet for 3D Segmentation of spleen",
metadata=dict(config['model_params']))
model_artifact.add_file(best_model_path)
wandb.log_artifact(model_artifact)
Notice that the model config parameters are passed to the metadata argument of wandb.Artifact(). This makes the versions more meaningful, because each one records the parameters used to train the model. You can pass any information you consider important for versioning the model, such as the model's score on the validation data. Here's what it looks like in the W&B dashboard:
W&B Model Registry UI. The checkpoints are shown as different versions in the left-hand pane. You can download the model from the Files tab or download it programmatically with the API.
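W&B tracks artifact contents by checksum, so re-logging a byte-identical checkpoint under the same name is expected to reuse the existing version rather than create a new one. The toy in-memory registry below illustrates that content-addressed idea with SHA-256 digests, with metadata stored alongside each version much like the `metadata` argument above. `ToyModelRegistry` is hypothetical and says nothing about how W&B implements artifacts internally.

```python
import hashlib


class ToyModelRegistry:
    """In-memory illustration of content-addressed versioning.

    A new version (v0, v1, ...) is created only when the bytes differ from
    every previously logged version; metadata is stored with each version.
    """

    def __init__(self):
        self.versions = []  # dicts with keys "version", "digest", "metadata"

    def log(self, data: bytes, metadata=None):
        digest = hashlib.sha256(data).hexdigest()
        for entry in self.versions:
            if entry["digest"] == digest:
                return entry["version"]  # unchanged contents: reuse the version
        version = f"v{len(self.versions)}"
        self.versions.append(
            {"version": version, "digest": digest, "metadata": dict(metadata or {})}
        )
        return version


# Placeholder bytes stand in for real checkpoint files.
registry = ToyModelRegistry()
print(registry.log(b"checkpoint A", {"model_type": "unet"}))  # v0
print(registry.log(b"checkpoint A"))                          # v0 (same bytes)
print(registry.log(b"checkpoint B"))                          # v1
```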



Viewing Logged Experiments

Since we logged the configuration by passing it to wandb.init, you can conveniently compare all the experiments in a table.





Logging Model Predictions

Error analysis is one of the most important parts of any machine learning project, and it matters even more in the medical field. W&B can help here: for example, you can log the input image, the ground truth segmentation mask, and the mask predicted by your model to see where the model makes mistakes. Once you identify those mistakes, you can improve the model to handle such samples. W&B Tables let you do this efficiently. To learn more about W&B Tables, you can read this report.
We will log only a subset of slices of each 3D image volume. Let's look at the code and the table created on the W&B dashboard.
PS: You can also log the metric score for each entry/image in the table to make your error analysis more specific. This is left as an exercise for the reader.
import os

import torch
import wandb
from monai.inferers import sliding_window_inference

# 🐝 create a wandb table to log input image, ground_truth masks and predictions
columns = ["filename", "image", "ground_truth", "prediction"]
table = wandb.Table(columns=columns)

model.load_state_dict(torch.load(
    os.path.join(root_dir, "best_metric_model.pth")))
model.eval()
with torch.no_grad():
    for i, val_data in enumerate(val_loader):
        # get the filename of the current image
        fn = val_data['image_meta_dict']['filename_or_obj'][0].split("/")[-1].split(".")[0]

        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model
        )
        # class index per voxel, computed once per volume instead of once per slice
        pred_volume = torch.argmax(val_outputs, dim=1).detach().cpu()

        # log 20 slices (indices 80 to 99) of each 3D image
        for slice_no in range(80, 100):
            img = val_data["image"][0, 0, :, :, slice_no]
            label = val_data["label"][0, 0, :, :, slice_no]
            prediction = pred_volume[0, :, :, slice_no]

            # 🐝 Add data to wandb table dynamically
            table.add_data(fn, wandb.Image(img), wandb.Image(label), wandb.Image(prediction))

# 🐝 log predictions table to wandb with `val_predictions` as key
wandb.log({"val_predictions": table})

# 🐝 Close your wandb run
wandb.finish()



You can see that our model is already quite good at segmenting the spleen.
You can also compare predictions across experiments. Here, I compare two recent experiments: the "0:" prefix before each filename refers to the run kind-dawn-65 (index 0), and the "1:" prefix refers to the run leafy-dawn-64 (index 1). You can compare more than two experiments in the same table.
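In the table code, each prediction slice comes from taking the argmax over the class dimension of the `(batch, 2, H, W, D)` sliding-window output and then indexing `[0, :, :, slice_no]`. Because `range(80, 100)` assumes every validation volume has at least 100 slices along that axis, a slightly more defensive variant skips out-of-range indices. The NumPy sketch below shows that logic on a tiny synthetic logit volume; `predicted_slices` is a hypothetical helper, and NumPy stands in for the PyTorch tensors used in the notebook.

```python
import numpy as np


def predicted_slices(logits, slice_indices):
    """Turn a (batch, classes, H, W, D) logit volume into 2D label maps.

    Takes the argmax over the class axis once for the whole volume, then
    returns {slice_index: (H, W) label map} for the first item in the batch.
    Indices beyond the volume's depth are skipped instead of raising.
    """
    logits = np.asarray(logits)
    if logits.ndim != 5:
        raise ValueError(f"expected (B, C, H, W, D), got {logits.shape}")
    labels = np.argmax(logits, axis=1)  # (B, H, W, D): class index per voxel
    depth = labels.shape[-1]
    return {k: labels[0, :, :, k] for k in slice_indices if 0 <= k < depth}


# Synthetic logits: class 1 ("spleen") wins only in slice 2; elsewhere the
# scores tie and argmax picks class 0 (background).
logits = np.zeros((1, 2, 4, 4, 3))
logits[0, 1, :, :, 2] = 1.0
out = predicted_slices(logits, range(1, 5))
print(sorted(out))                 # [1, 2]  (slices 3 and 4 do not exist)
print(out[2].sum(), out[1].sum())  # 16 0
```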






Parallel Coordinates Chart

The parallel coordinates chart summarizes the relationship between your hyperparameters and model metrics at a glance. You can include any hyperparameter logged to W&B in the chart to see how different hyperparameters relate to the metrics.



Parameter Importance

This panel shows which of your hyperparameters were the best predictors of, and most highly correlated with, desirable values of your metrics. You can learn more about it here.
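According to the W&B documentation, this panel reports two numbers per hyperparameter: an importance score derived from a random forest trained on the hyperparameters, and the linear correlation between the hyperparameter and the chosen metric. The NumPy sketch below reproduces only the simpler correlation column (Pearson correlation across runs). `hyperparameter_correlations` is a hypothetical helper, and the runs in the example are synthetic values, not results from this report.

```python
import numpy as np


def hyperparameter_correlations(runs, metric):
    """Pearson correlation between each numeric hyperparameter and a metric.

    `runs` is a list of dicts, one per run (e.g. config values plus summary
    metrics). Non-numeric or missing hyperparameters are skipped; ones that
    never vary get NaN, since correlation is undefined for a constant column.
    """
    y = np.array([r[metric] for r in runs], dtype=float)
    keys = sorted({k for r in runs for k in r if k != metric})
    result = {}
    for key in keys:
        values = [r.get(key) for r in runs]
        if not all(isinstance(v, (int, float)) for v in values):
            continue  # skip missing or non-numeric hyperparameters
        x = np.array(values, dtype=float)
        if np.std(x) == 0 or np.std(y) == 0:
            result[key] = float("nan")
        else:
            result[key] = float(np.corrcoef(x, y)[0, 1])
    return result


# Synthetic runs (not results from this report): the metric is an exact
# linear function of the learning rate, and the batch size never varies.
synthetic_runs = [
    {"learning_rate": lr, "train_batch_size": 2, "model_type": "unet",
     "val/dice_metric": 100 * lr + 0.1}
    for lr in (1e-4, 5e-4, 1e-3)
]
print(hyperparameter_correlations(synthetic_runs, "val/dice_metric"))
```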



Summary

In this article, we saw how W&B can aid research in the medical field and help you get more out of your workflow. Features like Tables let you perform error analysis on the model. We also saw how to version models. Lastly, we saw how parallel coordinates charts and parameter importance charts can help you choose hyperparameters.
To see the full suite of W&B features, check out this 5-minute guide. You can learn about all the features in depth from the W&B Guides and the W&B YouTube channel.
Check out other reports on Fully Connected covering a wide range of topics.