
Using Vertex AI and Weights & Biases to Enhance Image Resolution

In this article, we'll take a look at how to train a model that enhances image resolution using Google's Vertex AI Platform and Weights & Biases.
Researchers have been making incredible strides in techniques for enhancing image resolution, a problem generally referred to as image super-resolution. In this tutorial, we will train a residual dense network (RDN) implemented as a LightningModule.
While RDNs have been surpassed in recent years by adversarial networks, diffusion models, and variational autoencoders, they are straightforward convolutional neural networks (CNNs), and you can get great results from scratch with only two hours of training on a T4 GPU.
Low-resolution image before and after being passed through our network
The purpose of this piece is both to walk through a super-resolution model training workflow and to demonstrate in depth how to use Vertex AI with W&B. We'll spend more time on the latter, with special attention paid to logging and hyperparameter optimization with Sweeps.

Source Code

All code for this tutorial is available here.

Weights & Biases

The document you're reading right now is a Weights & Biases Report, and we'll use Weights & Biases to log and visualize metrics, gradients, serialized models, and predictions.

PyTorch Lightning

Our network is implemented with PyTorch Lightning, which makes it easy to port and scale our model. It also means we can use PyTorch Lightning's Weights & Biases integration to record our runs. You can learn more about the integration in the PyTorch Lightning docs.


Google Vertex AI

Vertex is an AI platform offered by Google. In addition to a wide array of AutoML products, Vertex AI provides an API that makes it easy to run highly customized ML workloads on GCP. In this tutorial, we'll use the Vertex Python SDK to execute GPU-accelerated training jobs on demand in the cloud.
Our network can be trained on any machine with a properly configured Python environment, but in order to run it for yourself on Vertex AI you will need:
  1. A GCP project with the Vertex AI API enabled
  2. A GCS bucket for the dataset and for staging
  3. Access to Artifact Registry and a local Docker installation to build and push the training container
  4. A Weights & Biases account and API key

The Div2k Dataset

In order to train our network, we will need to provide it with pairs of low and high-resolution versions of the same image. Data labeling won't be an issue since we can generate a lower-resolution version of any high-resolution image available.
The dataset we'll be using for our experiments today is the Diverse 2K High-Resolution Images (Div2k) dataset. It's a common super-resolution benchmark consisting of 1,000 high-resolution images and several low-resolution versions of each, generated with various downscaling methods. We will use the pre-made train/validation split of Div2k, with training inputs that have been downsampled by a factor of 4 in both dimensions using bicubic interpolation.
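As a quick illustration of what that downsampling looks like (a sketch using Pillow with hypothetical file names; Div2k already ships these low-resolution inputs, so you don't need to generate them yourself):

from PIL import Image

hr = Image.open("0001.png")  # a high-resolution source image
# Downsample by a factor of 4 in both dimensions with bicubic interpolation
lr = hr.resize((hr.width // 4, hr.height // 4), resample=Image.BICUBIC)
lr.save("0001_x4.png")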
For use with Vertex, we'll want the data in a GCS bucket. You can use the download.sh script to download the image archives and unpack them into the div2k directory. Copy div2k to a GCS bucket of your choice with gsutil or the GCS console.
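For example, with the same placeholder bucket name used below, the copy is a single recursive gsutil command:

$ gsutil -m cp -r div2k gs://<bucket-name>/div2k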
You can run the register_dataset.py script to register your GCS dataset as an Artifact in Weights & Biases. The script will construct a new Artifact named div2k, store the location and metadata of our images using Artifact references, and log that Artifact to our Weights & Biases project. Use python register_dataset.py <bucket-name> 4 and your Artifact will be created.
Executing register_dataset.py to create our dataset Artifact
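Under the hood, the script does something along these lines (a minimal sketch; the metadata field is illustrative, but the Artifact name, type, and job type match the lineage view below):

import wandb

run = wandb.init(project="vertex-super-resolution", job_type="register_data")

artifact = wandb.Artifact("div2k", type="image-dataset", metadata={"downsample_factor": 4})
artifact.add_reference("gs://<bucket-name>/div2k")  # store references to the images, not copies
run.log_artifact(artifact)
run.finish()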
Why do we do this? Well, for starters, wandb handles everything when we download the dataset: all we have to do is call wandb.use_artifact(...) and Artifact.download(). It also gives us a tidier interface for our training code: rather than passing in the exact location of our data every time we want to train a model, we can simply specify the dataset version, e.g., div2k:latest.
But the biggest reason we use Artifacts is actually bookkeeping. If we make any changes to our dataset, Weights & Biases will keep track of the exact dataset version used by each of our training runs, which is extremely helpful whenever you need to reproduce or debug an experiment.
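On the training side, pulling the data down is then just a couple of calls (a sketch, assuming the same project):

import wandb

run = wandb.init(project="vertex-super-resolution", job_type="train")
dataset = run.use_artifact("div2k:latest")  # or pin a specific version, e.g., div2k:v2
data_dir = dataset.download()               # resolves the GCS references to local files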
The Weave panel below shows a view of the dataset I've created for my project:

(Direct lineage view: the div2k image-dataset Artifact at version v2, logged by the register_data runs azure-forest-19 and rare-flower-20.)

Training Our Model On Vertex AI

Training Container

We'll train our model on Vertex AI using their Python SDK. Run this...
$ docker build . -t <artifact-registry-uri> # Substitute your artifact registry URI
$ docker push <artifact-registry-uri>
...in the source repository to build our training code into a Docker container and push that container to your Artifact Registry. The container we're building uses Vertex's preconfigured PyTorch GPU image as a base. The base image is about 16GB, so the first time you push or pull, it may take a while.

Training Code

We can train a model by running train.py in our container 🎉 The training script can be configured with command line arguments. For example:
$ python train.py div2k:latest --lr=0.0001 --train_batch_size=32 --nblocks=16 --nlayers=8
For all options, run this:
$ python train.py --help
Thanks to PyTorch Lightning, the training script will make use of any available GPU. wandb is used in several places in train.py and in the LightningModule implementation, model.py.

Logging Metrics With Weights & Biases

We can report metrics from inside our training_step and validation_step using LightningModule.log. During training, our LightningModule will feed the reported metrics to any loggers configured for our Trainer. PyTorch Lightning provides out-of-the-box support for logging to Weights & Biases via WandbLogger (docs). We also pass the LearningRateMonitor callback to our Trainer so that we can see the value of the learning rate at each point in training.
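Wired together, the relevant pieces look roughly like this (a sketch: the module and metric names are illustrative, and the network definition is elided):

import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

class SuperResModule(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        lowres, highres = batch
        loss = F.l1_loss(self(lowres), highres)
        self.log("train/loss", loss)  # forwarded to every logger attached to the Trainer
        return loss

    def validation_step(self, batch, batch_idx):
        lowres, highres = batch
        loss = F.l1_loss(self(lowres), highres)
        self.log("val/loss", loss)

wandb_logger = WandbLogger(project="vertex-super-resolution")
trainer = pl.Trainer(
    logger=wandb_logger,
    callbacks=[LearningRateMonitor(logging_interval="step")],
)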



Logging Gradients

wandb can log the distribution of parameters and gradients automatically for PyTorch models. All you need to do is call wandb.watch(<your-pytorch-module>). This kind of detailed logging is super helpful when you are debugging strange training behavior.
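With the Lightning integration, this is a one-liner (a sketch, assuming the WandbLogger instance from above and a model variable holding our LightningModule; the frequency is illustrative):

# Log histograms of parameters and gradients every 200 training steps
wandb_logger.watch(model, log="all", log_freq=200)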



Logging Validation Predictions

To really evaluate the quality of these models, we need to look at predictions. The wandb.Image class allows us to log image tensors during training and have them rendered in Weights & Biases. Our LightningModule is written so that all validation predictions are logged in a wandb.Table that includes the low-resolution input, the predicted image, and the high-resolution target, as well as the loss on each prediction.

(Validation predictions table with columns: step, lowres, superres, highres, and l1_loss.)
The wandb.Table that we log needs to be constructed incrementally since we run these validation predictions in batches. We can configure this by overriding a few of Lightning's convenient hooks 🪝
Here is an outline of how it's done:
import wandb
import pytorch_lightning as pl

class MyLightningModule(pl.LightningModule):
    def on_validation_start(self):
        # Initialize a new table for this validation pass
        self.table = wandb.Table(columns=["step", "lowres", "superres", "highres", "l1_loss"])

    def validation_step(self, batch, batch_idx):
        ...  # Make predictions
        self.table.add_data(...)  # Add a row of predictions to the existing table

    def on_validation_end(self):
        # Log the completed table to Weights & Biases
        self.logger.experiment.log({"predictions": self.table})
        del self.table  # Necessary to free up memory

Logging Model Weights

A great feature of the WandbLogger is that it will automatically checkpoint your models as Artifacts if you instantiate it with WandbLogger(log_model='all'). You can then configure how checkpoints will be scheduled with Lightning's ModelCheckpoint callback.
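For example (a sketch; the monitored metric matches the validation loss we log above, and save_top_k is illustrative):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="vertex-super-resolution", log_model="all")
checkpoint_callback = ModelCheckpoint(monitor="val/loss", mode="min", save_top_k=2)
trainer = pl.Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])

Each checkpoint is then uploaded as a new version of a model Artifact, like the one shown below.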

(Version overview of a checkpoint Artifact: wandb-smle/vertex-super-resolution/model-03aj6sl9:v13, aliases best, latest, and v13, a single 136.0 MB file, created July 15th, 2022.)

Vertex AI Training

vertex_train.py will use the aiplatform SDK to run our training container on whatever hardware we want! The aiplatform.CustomContainerTrainingJob class allows you to specify a training job as a Python object, which you can execute by calling CustomContainerTrainingJob.run().
For example, we can instantiate a job with...
job = aiplatform.CustomContainerTrainingJob(
    display_name=<my-job-name>,
    container_uri=<my-container-uri>,
    command=["python3", "train.py", ...],
)
...and execute that job with this code:
job.run(
    machine_type="n1-standard-8",
    accelerator_type="NVIDIA_TESLA_T4",
    accelerator_count=1,
    environment_variables={...},
)
If we want to run our model on a bigger machine, we can just change the machine_type or accelerator_type in our script 😎 (here's a list of options).
We can also use the environment_variables argument of the run() function to pass environment variables to our container. By default, vertex_train.py will pass the values of WANDB_API_KEY, WANDB_ENTITY, and WANDB_PROJECT from your local environment to the container, so if you set those earlier, the wandb configuration is already done.
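One piece of setup worth noting: before constructing the job, the SDK needs to be pointed at your project and staging bucket, roughly like so (the region here is an assumption):

from google.cloud import aiplatform

aiplatform.init(
    project="<gcp-project>",
    location="us-central1",  # assumed region; use whichever region hosts your resources
    staging_bucket="gs://<staging-bucket>",
)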
Running this code...
$ python vertex_train.py <gcp-project> <container-uri> <staging-bucket>
...will execute the training container with default hyperparameters and the div2k:latest artifact as our dataset. To override the arguments to train.py, you can run vertex_train.py with the --args flag and pass a comma-separated list of arguments that will be forwarded to train.py. For example:
$ python vertex_train.py ... --args="div2k:v2,--lr=0.0001,--nblocks=8"
The instance will take a few minutes to spin up, but after a short wait, you will see a new training run in your Weights & Biases project. Metrics and console logs will appear in real time as training progresses, and you can even stop the run from the Weights & Biases UI if training is going poorly. You also get live, highly detailed system metrics for each run, which is really useful for tuning the batch size and data loader thread count.



Hyperparameter Optimization

Once you're using wandb to track your experiments, it becomes trivial to perform hyperparameter optimization (HPO) using Sweeps.

Sweep Config

The first step in executing a Sweep is to create a configuration file defining
  1. The parameters we want to search over
  2. The method we should use to guide our search (options: bayes, grid, random)
  3. The metric we should optimize
  4. The program that should be invoked for model training
  5. The command that should be used to invoke our training program
Take the configuration below, for example:
command:
  - ${env}
  - ${interpreter}
  - ${program}
  - div2k:4x
  - ${args}
method: bayes
metric:
  name: val/loss
  goal: minimize
parameters:
  log_grad:
    value: 200
  lr:
    values:
      - 0.0001
      - 1e-05
      - 5e-05
  nblocks:
    values:
      - 16
      - 8
  nlayers:
    values:
      - 4
      - 8
  train_batch_size:
    values:
      - 32
      - 64
program: train.py
Here, we define a search in which the metric val/loss will be minimized using Bayesian hyperparameter optimization. The optimizer is free to set nblocks, nlayers, train_batch_size, and lr to any of the values provided. Each training run will be started by the sweep agent using the command:
$ env python train.py div2k:4x --log_grad=200 --nlayers=<nlayers> ... # etc
Each parameter will be passed as a command line argument to our training script, positioned inside the command using the ${args} variable. You can learn more about customizing the sweep commands here.
Once we have created a sweep config as either a YAML file or an in-memory dictionary, we can create our sweep by uploading the config to Weights & Biases. This can be done by running...
$ wandb sweep <sweep-config-yaml>
...from the command line with a YAML file or...
wandb.sweep(<sweep-config-dict>)
...with a Python dictionary.
Creating a sweep from a YAML config using the wandb sweep command
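The dictionary route, for comparison, might look like this (a sketch mirroring the YAML config above):

import wandb

sweep_config = {
    "program": "train.py",
    "method": "bayes",
    "metric": {"name": "val/loss", "goal": "minimize"},
    "command": ["${env}", "${interpreter}", "${program}", "div2k:4x", "${args}"],
    "parameters": {
        "lr": {"values": [1e-4, 5e-5, 1e-5]},
        "nblocks": {"values": [16, 8]},
        "nlayers": {"values": [4, 8]},
        "train_batch_size": {"values": [32, 64]},
        "log_grad": {"value": 200},
    },
}
sweep_id = wandb.sweep(sweep_config, project="vertex-super-resolution")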

Sweep Agent

Once you have created a sweep, you can begin pulling and executing hyperparameter combinations by running an agent. You can start an agent for a particular sweep by running...
$ wandb agent <sweep-id>
...from the command line or...
wandb.agent(<sweep-id>)
...in Python.
Once an agent starts, it will poll the Weights & Biases server for a hyperparameter combination. When it receives a combination, it will execute the configured command in its environment, so you need to make sure you are running the agent in an environment where the command you have configured will actually kick off a training job.
The agent will pass the hyperparameters to your training script as command line arguments (if you have included ${args} in your command) but will also ensure that any variables in wandb.config are set with values supplied by the sweep controller when you call wandb.init in your training code.
A single agent will pull a parameter combination, run to completion, and repeat. You can run multiple agents concurrently if you want to parallelize your search.

Sweeps on Vertex AI

The controller for the sweep will run on Weights & Biases, and we can upload the config from any machine we want (or enter the config directly into the Weights & Biases UI). We can leverage Vertex AI to run the actual training by invoking our training container with the wandb agent command instead of python train.py.
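In other words, the sweep job is the same CustomContainerTrainingJob as before, just with a different command (a sketch; the display name and sweep ID path are illustrative):

job = aiplatform.CustomContainerTrainingJob(
    display_name="div2k-sweep-agent",
    container_uri="<container-uri>",
    command=["wandb", "agent", "<entity>/<project>/<sweep-id>"],
)
job.run(
    machine_type="n1-standard-8",
    accelerator_type="NVIDIA_TESLA_T4",
    accelerator_count=1,
    environment_variables={"WANDB_API_KEY": "..."},
)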
The vertex_train.py script generates a new sweep from a YAML config and then executes the sweep by starting an agent on Vertex AI. Run...
$ python vertex_train.py <sweep-config-yaml> <gcp-project> <container-uri> <staging-bucket>
...to generate and run a sweep on Vertex from your local sweep config.



Conclusion

In this Report, we walked through a super-resolution training pipeline, pairing W&B with Vertex AI, and covered a few tips and tricks for making the two work together seamlessly.