
Implementing NeRF in JAX

This article uses JAX to create a minimal implementation of 3D volumetric rendering of scenes represented by Neural Radiance Fields, using W&B to track all metrics.
In this article, we attempt to create a minimal implementation of 3D volumetric rendering of scenes represented by Neural Radiance Fields (NeRF) using JAX. The ideas used in this implementation were proposed in the paper NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis, where the authors propose a method that achieves state-of-the-art results for synthesizing novel views of complex scenes by optimizing an underlying continuous volumetric scene function using a sparse set of input views.
We train a model to learn the NeRF representation of a simple 3D scene using JAX and Flax on Google Cloud TPUs.
What is JAX?

JAX is an accelerated computation framework that brings together Autograd and XLA for high-performance machine learning research. It provides a simple NumPy and SciPy-like interface for writing code, compiled by the XLA Compiler to run on CPU, GPU, and TPU.



JAX also supports just-in-time (JIT) compilation of Python functions into XLA-optimized kernels using a one-function API. Due to its functional programming paradigm, JAX allows us to use composable transforms to transform a function without modifying it.
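As a quick illustration (a minimal sketch, not part of the NeRF code that follows), composing jax.grad and jax.jit turns a plain Python function into a compiled gradient function:

import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a simple linear model
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

# Compose transforms without modifying the original function:
# jax.grad differentiates w.r.t. the first argument, jax.jit compiles the result with XLA
grad_fn = jax.jit(jax.grad(loss))
gradients = grad_fn(jnp.zeros(3), jnp.ones((8, 3)), jnp.ones(8))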

What is Flax?

Flax is a neural network library and ecosystem for JAX, designed for flexibility. It provides a neural network API, flax.linen, that makes it easy for us to create models in a Pythonic manner.
Additionally, Flax provides utilities and patterns for replicated training, serialization and checkpointing, metrics, and prefetching on the device.
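As a quick taste of the API (a toy model, unrelated to the NeRF model we build later), a flax.linen module is defined as a class and used through the init/apply pattern:

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyMLP(nn.Module):
    features: int = 16

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.features)(x))
        return nn.Dense(1)(x)

tiny_model = TinyMLP()
variables = tiny_model.init(jax.random.PRNGKey(0), jnp.ones((4, 3)))  # initialize the parameters
outputs = tiny_model.apply(variables, jnp.ones((4, 3)))               # forward pass -> shape (4, 1)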


Jumping into the Code

The code that we'll demonstrate in this report works with Google Cloud TPUs. You can run it for free on Google Colab and Kaggle by setting the accelerator to TPU in both cases.


Dependencies

!pip3 install -q -U jax jaxlib flax optax imageio-ffmpeg wandb
Other than JAX, Jaxlib, and Flax, we also install the following dependencies:
  • Optax: a gradient processing and optimization library for JAX. Optax is designed to facilitate research by providing building blocks that can be easily recombined in custom ways.
  • ImageIO-FFMPEG: a simple and reliable FFmpeg wrapper for working with video files.
  • Weights & Biases for tracking all metrics and losses of our experiments in real-time and creating rich, interactive dashboards.
Now, let's import all the necessary libraries:
import os
import time
import wandb
import imageio
import requests
from typing import Any
import ipywidgets as widgets
from functools import partial
from tqdm.notebook import tqdm
from kaggle_secrets import UserSecretsClient

import jax
import flax
import optax
from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils

import numpy as np
import jax.numpy as jnp

from base64 import b64encode
from IPython.display import HTML
import plotly.express as px
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid


Setting up the TPU for JAX

Tensor Processing Units (or TPUs) are hardware accelerators specializing in deep learning tasks. They were created by Google and have been behind many cutting-edge results in machine learning research. For a detailed guide to using TPUs on Kaggle Notebooks, you can refer to https://www.kaggle.com/docs/tpu.
Let's set them up:

# Detect if Kaggle Notebook has access to TPUs or not
if 'TPU_NAME' in os.environ:
    import requests
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print("TPU DETECTED!")
    print('Registered TPU:', config.FLAGS.jax_backend_target)

# Detect if Google Colab Notebook has access to TPUs or not
elif "COLAB_TPU_ADDR" in os.environ:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()

else:
    print('No TPU detected.')

DEVICE_COUNT = len(jax.local_devices())
TPU = DEVICE_COUNT == 8

if TPU:
    print("8 cores of TPU ( Local devices in Jax ):")
    print('\n'.join(map(str, jax.local_devices())))

Setting Up Our Weights & Biases Run

We can now call wandb.init to initialize a new job. This creates a new run in Weights & Biases and launches a background process to sync data. We sync all the configs of our experiments with the W&B run, which makes it far easier for us to reproduce the results of the experiment later.
wandb.login()
wandb.init(project="nerf-jax", job_type="train")

# sync all experiment configs with Weights & Biases
config = wandb.config
config.near_bound = 2. # Near Bound of sample space for 3d points
config.far_bound = 6. # Far Bound of sample space for 3d points
config.batch_size = int(1e4) # Batch Size
config.num_sample_points = 256 # Number of points to be sampled across the volume
config.epsilon = 1e10 # Hyperparameter for volume rendering
config.apply_positional_encoding = True # Apply positional encoding to the input or not
config.positional_encoding_dims = 6 # Number of positional encodings applied
config.num_dense_layers = 8 # Number of dense layers in MLP
config.dense_layer_width = 256 # Dimensionality of dense layers' output space
config.learning_rate = 5e-4 # Learning Rate
config.train_epochs = 1000 # Number of training epochs
config.plot_interval = 100 # Epoch interval for plotting results during training
We'll discuss what each of these configs represents in the later sections of this report.



What is a NeRF (Neural Radiance Field)?

A NeRF (Neural Radiance Field) is a neural network that maps a point and viewing direction in 3D space to the amount of light emitted by that point in each direction, without discretizing the scene. It allows us to synthesize novel (new) views of complex scenes.
To clarify, let's break it down into its component parts:
  • The word neural obviously means that there's a Neural Network involved
  • Radiance refers to the radiance of the scene that the neural network outputs. It describes how much light is being emitted by a point in space in each direction, and
  • The word Field means that the Neural Network models a continuous and non-discretized representation of the scene it learns.
Figure: Overview of the Model
  • The neural network F_θ in this case is a simple multi-layer perceptron (MLP) with ReLU activations.
  • The MLP consists of 8 fully-connected layers of width 256, followed by an output layer.
  • The input to the MLP consists of 2 components:
    • (x, y, z), which denotes the spatial position of a given point in 3D space.
    • (θ, φ), which denotes a viewing direction from the point.
    • Together, (x, y, z, θ, φ) forms a single continuous 5D coordinate that is fed to the MLP.
  • The output of the MLP also consists of 2 components:
    • (r, g, b), which denotes the color emitted by the point (x, y, z) along the direction (θ, φ) in RGB colorspace.
    • σ, which denotes the density or transparency of the point.
    • The value of σ lies in the range [0, ∞).
    • A σ value of 0 means there is nothing at the point (the point is fully transparent), and a σ value of ∞ means that the point is opaque.

Implementing the Neural Network in Flax

We'll use the flax.linen API to create the neural network FθF_{\theta}. So what is linen exactly? Simple: linen is a neural network API developed as part of the Flax package. It builds on a functional core, enabling direct usage of JAX transformations such as vmap, remat, or scanning inside our modules.
We'll create our NeRFModel as a subclass of flax.linen.Module, similar to how we would use the torch.nn.Module API in PyTorch or the tf.keras.Model API in TensorFlow.
class NeRFModel(nn.Module):
    # dtype of the computation
    dtype: Any = jnp.float32
    # numerical precision of the computation
    precision: Any = lax.Precision.DEFAULT
    # Apply positional encoding to the inputs or not
    apply_positional_encoding: bool = config.apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        # Apply positional encoding to the input points
        x = positional_encoding(input_points) if self.apply_positional_encoding else input_points
        for i in range(config.num_dense_layers):
            # Fully-connected layer
            x = nn.Dense(
                config.dense_layer_width,
                dtype=self.dtype,
                precision=self.precision
            )(x)
            # ReLU activation function
            x = nn.relu(x)
            # Skip connection: re-inject the input points halfway through the MLP
            x = jnp.concatenate([x, input_points], axis=-1) if i == 4 else x
        # Output consists of 4 values: (r, g, b, sigma)
        x = nn.Dense(4, dtype=self.dtype, precision=self.precision)(x)
        return x



Positional Encodings

Let's step back from our NeRFModel and consider a simpler problem to act as a proxy for NeRF:
Make a neural network memorize a single image.
The model takes as input a pixel coordinate (x, y) and predicts the value of that pixel. We train this model by overfitting it on a single image, with the objective of memorizing that image accurately. Note that we consider the model to be an MLP with a ReLU activation function, similar to our NeRFModel.
The surprising thing is that such a model can contain 8 to 10 times more parameters than there are pixels in the image and still do a rather poor job, even when given an extremely generous amount of training time.
It's almost as if we can see the ReLU activation in the attempted reconstruction 😱

What's The Solution?

Thankfully, the NeRF authors found a solution: they observed that if we take sinusoids of the input coordinates at increasing frequencies, expand the coordinates into a high-dimensional feature vector, and feed the concatenated features into the MLP instead of the raw coordinates, the model is able to memorize the image!!!
Although the idea of Fourier Feature Mapping was presented in the original NeRF paper we're focusing on, the authors didn't provide an explanation as to why it worked. The explanation came in a later work titled Fourier Features Let Networks Learn High-Frequency Functions in Low-Dimensional Domains.
In fact, the explanation is right there in the title. Passing input points through a simple Fourier Feature Mapping enables an MLP to learn high-frequency functions (such as an RGB image) in low-dimensional problem domains (such as the 2D coordinates of pixels).
This pre-processing step of encoding the input is called Fourier Feature Mapping or Positional Encoding. This encoding γ(v) can be represented as

\Large{\gamma(v) = \begin{pmatrix} \sin(v), \cos(v)\\ \sin(2v), \cos(2v)\\ \sin(4v), \cos(4v)\\ \vdots\\ \sin(2^{L-1}v), \cos(2^{L-1}v) \end{pmatrix}}

or,

\Large{\gamma(v) = [\sin(2\pi Bv), \cos(2\pi Bv)]^{T}}


As we can see, the MLP with Fourier features is able to memorize the image quickly 🚀

Let's check what the Fourier Feature Mapping looks like in code:
def positional_encoding(inputs):
    batch_size, _ = inputs.shape
    # Apply the vmap transform to vectorize the multiplication operation
    inputs_freq = jax.vmap(
        lambda x: inputs * 2.0 ** x
    )(jnp.arange(config.positional_encoding_dims))
    periodic_fns = jnp.stack([jnp.sin(inputs_freq), jnp.cos(inputs_freq)])
    periodic_fns = periodic_fns.swapaxes(0, 2).reshape([batch_size, -1])
    periodic_fns = jnp.concatenate([inputs, periodic_fns], axis=-1)
    return periodic_fns
Note that we apply the composable transform jax.vmap on the multiplication operation, which "vectorizes" the operation. What this means is that it allows us to compute the output of a function in parallel over some axis of the input.
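As a quick sanity check (assuming the config values defined above, i.e. positional_encoding_dims = 6), each 3D input point is expanded from 3 to 3 + 3 × 2 × 6 = 39 features:

sample_points = jnp.ones((config.batch_size, 3))
encoded_points = positional_encoding(sample_points)
print(encoded_points.shape)  # (10000, 39): the original 3 coordinates plus sin/cos at 6 frequencies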



View Synthesis

Before we get into discussing view synthesis, let's grab our dataset and take a look at it.
We'll be using the Tiny-NeRF dataset, which is a subset of the original Blender dataset used by the authors. We use the first 100 images and their respective poses as training data. For validation, we use a single image and pose pair.
The small validation subset is not going to be an issue in our case since we essentially want our network to "memorize" the 3D scene given by the data in a way that enables smooth interpolation between frames.
!wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz

data = np.load("tiny_nerf_data.npz")
images = data["images"]
poses = data["poses"]
focal = float(data["focal"])

_, image_height, image_width, _ = images.shape

# We would use the first 100 images, poses as training data
train_images, train_poses = images[:100], poses[:100]
# We use a single image, pose pair for validation
val_image, val_pose = images[101], poses[101]


# Visualize the training data
fig = plt.figure(figsize=(16, 16))
grid = ImageGrid(fig, 111, nrows_ncols=(4, 4), axes_pad=0.1)
random_images = images[np.random.choice(np.arange(images.shape[0]), 16)]
for ax, image in zip(grid, random_images):
    ax.imshow(image)
plt.title("Sample Images from Tiny-NeRF Data")
plt.show()


The dataset consists of a set of images of this lego model taken from multiple poses. The task at hand is to learn a representation of this object such that we can synthesize a view of the object from any given pose, even ones that are not present in the dataset.
Figure: An example of images collected from various camera poses along a spherical surface.

View Synthesis using Volume Rendering

Once we have a dense sampling of views, represented here by the continuous, effectively infinite-resolution radiance field defined by the MLP NeRFModel, photorealistic novel views can be reconstructed by simple light field sample interpolation techniques.
For synthesizing novel views, the authors of NeRF propose using volumetric representations to address the task of high-quality, photorealistic view synthesis from a set of input RGB images. Such volumetric approaches are able to realistically represent complex shapes and materials, are well-suited for gradient-based optimization, and tend to produce less visually distracting artifacts than mesh-based methods.
We render the color of any ray passing through the scene using principles from classical volume rendering. The expected color C(r) of a camera ray r(t) with near bound t_n and far bound t_f is given by
\Large{C(r) = \int_{t_{n}}^{t_{f}} T(t)\,\sigma(r(t))\,c(r(t))\,dt}

where...
  • σ(r(t)) is the volume density predicted by the MLP, which can be interpreted as the differential probability of a ray terminating at an infinitesimal particle at the location r(t). It is basically the opacity of the given point.
  • c(r(t)) is the RGB color predicted by the MLP, which can be interpreted as the radiance of a ray terminating at the location r(t).
  • T(t) denotes the accumulated transmittance along the ray r(t) from the near bound t_n to t, i.e., the probability that the ray travels from t_n to t without hitting any other particle.
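In the implementation that follows, this integral is approximated with the quadrature rule used in the paper, summing over the N samples drawn along each ray:

\Large{\hat{C}(r) = \sum_{i=1}^{N} T_{i}\left(1 - e^{-\sigma_{i}\delta_{i}}\right)c_{i}, \quad T_{i} = exp\left(-\sum_{j=1}^{i-1}\sigma_{j}\delta_{j}\right)}

where δ_i = t_{i+1} − t_i is the distance between adjacent samples; the functions below compute these quantities step by step.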


In order to perform volume rendering, we have to perform the following steps:
Generate Rays: Given a pose, we have to generate a grid of rays corresponding to the dimensions of the image that we desire to render. This is done by the following function:
def generate_rays(height, width, focal, pose):
    # Create a 2D rectangular grid for the rays corresponding to image dimensions
    i, j = np.meshgrid(np.arange(width), np.arange(height), indexing="xy")
    transformed_i = (i - width * 0.5) / focal  # Normalize the x-axis coordinates
    transformed_j = -(j - height * 0.5) / focal  # Normalize the y-axis coordinates
    k = -np.ones_like(i)  # z-axis coordinates
    # Create the unit vectors corresponding to ray directions
    directions = np.stack([transformed_i, transformed_j, k], axis=-1)
    # Compute the origin and direction of each ray
    camera_directions = directions[..., None, :] * pose[:3, :3]
    ray_directions = np.einsum("ijl,kl", directions, pose[:3, :3])
    ray_origins = np.broadcast_to(pose[:3, -1], ray_directions.shape)
    return np.stack([ray_origins, ray_directions])

Compute 3D query points t between the near bound t_n and far bound t_f. We use a stratified sampling approach where we partition [t_n, t_f] into N evenly-spaced bins and then draw one sample uniformly at random from within each bin. This action is performed by the following function:
def compute_3d_points(ray_origins, ray_directions, random_number_generator=None):
    """Compute 3d query points for volumetric rendering"""
    # Sample space to parametrically compute the ray points
    t_vals = np.linspace(config.near_bound, config.far_bound, config.num_sample_points)
    if random_number_generator is not None:
        # Inject uniform noise into the sample space to make it continuous
        t_shape = ray_origins.shape[:-1] + (config.num_sample_points,)
        noise = jax.random.uniform(
            random_number_generator, t_shape
        ) * (config.far_bound - config.near_bound) / config.num_sample_points
        t_vals = t_vals + noise
    # Compute the ray traversal points using r(t) = o + t * d
    ray_origins = ray_origins[..., None, :]
    ray_directions = ray_directions[..., None, :]
    t_vals_flat = t_vals[..., :, None]
    points = ray_origins + ray_directions * t_vals_flat
    return points, t_vals
Note that N is given by config.num_sample_points in our code.
💡
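Concretely, the stratified sampling scheme from the paper draws the i-th sample along a ray as

\Large{t_{i} \sim \mathcal{U}\left[t_{n} + \frac{i-1}{N}(t_{f} - t_{n}),\; t_{n} + \frac{i}{N}(t_{f} - t_{n})\right]}

which the code above approximates by adding uniform noise of width (t_f − t_n)/N to evenly-spaced samples.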

Compute the Radiance Field: We get the colors c(r(t)) and opacities σ(r(t)) from the NeRFModel. This action is performed by the following function:
def compute_radiance_field(model, points):
    """Compute the radiance field"""
    # Perform forward propagation in batches
    model_output = lax.map(model, jnp.reshape(points, [-1, config.batch_size, 3]))
    radiance_field = jnp.reshape(model_output, points.shape[:-1] + (4,))
    # Slice the model output into opacities and colors
    opacities = nn.relu(radiance_field[..., 3])
    colors = nn.sigmoid(radiance_field[..., :3])
    return opacities, colors

Compute Adjacent Distances δ_i = t_{i+1} − t_i between adjacent samples along the ray. This action is performed by the following function:
def compute_adjacent_distances(t_vals, ray_directions):
    """Get distances between adjacent intervals along the sample space"""
    distances = t_vals[..., 1:] - t_vals[..., :-1]
    distances = jnp.concatenate([
        distances, np.broadcast_to(
            [config.epsilon], distances[..., :1].shape
        )], axis=-1
    )
    # Multiply each distance by the norm of its corresponding direction ray
    # to convert to real world distance (accounts for non-unit directions)
    distances = distances * jnp.linalg.norm(ray_directions[..., None, :], axis=-1)
    return distances

Compute Transmittance: The accumulated transmittance up to sample i along the ray r(t) is given by T_i = exp(−Σ_{j=1}^{i−1} σ_j δ_j), where δ_j are the distances between adjacent samples. This action is performed by the following function:
def compute_weights(opacities, distances):
    """Compute the weight for the RGB of each sample along each ray"""
    # Per-segment transmittance exp(-sigma * delta) computed from the opacity
    density = jnp.exp(-opacities * distances)
    alpha = 1.0 - density
    clipped_difference = jnp.clip(1.0 - alpha, 1e-10, 1.0)
    # A cumulative product is used to express the idea
    # of the ray not having reflected up to this sample yet
    transmittance = jnp.cumprod(
        jnp.concatenate([
            jnp.ones_like(clipped_difference[..., :1]),
            clipped_difference[..., :-1]], -1
        ), axis=-1
    )
    return alpha * transmittance

Now, all we need to do is bring all the steps together to perform volume rendering and synthesize the desired view.
def perform_volume_rendering(model, ray_origins, ray_directions, random_number_generator=None):
    # Compute 3d query points
    points, t_vals = compute_3d_points(
        ray_origins, ray_directions, random_number_generator
    )

    # Get colors and opacities from the model
    opacities, colors = compute_radiance_field(model, points)

    # Get distances between adjacent intervals along the sample space
    distances = compute_adjacent_distances(t_vals, ray_directions)

    # Compute the weight for the RGB of each sample along each ray
    weights = compute_weights(opacities, distances)

    # Compute the weighted RGB color of each sample along each ray
    rgb_map = jnp.sum(weights[..., None] * colors, axis=-2)
    # Compute the estimated depth map
    depth_map = jnp.sum(weights * t_vals, axis=-1)
    # Sum of weights along each ray; the value is in [0, 1] up to numerical error
    acc_map = jnp.sum(weights, axis=-1)
    # The disparity map is the inverse of the (weight-normalized) depth
    disparity_map = 1. / jnp.maximum(1e-10, depth_map / jnp.sum(weights, axis=-1))

    return rgb_map, depth_map, acc_map, disparity_map, opacities
The Tiny NeRF dataset does not include view directions, hence the inputs used in this report are 3D instead of 5D as presented in the paper.
💡



Training The Model

Initializing The Model

First, the code.
def initialize_model(key, input_pts_shape):
    # Create an instance of the model
    model = NeRFModel()

    # Initialize the model parameters
    initial_params = jax.jit(model.init)(
        {"params": key},
        jnp.ones(input_pts_shape),
    )
    return model, initial_params["params"]
Note that we apply the jax.jit transform to the model initialization so that it is just-in-time compiled with XLA. This makes it more efficient to run on an accelerator, which in our case is a TPU. Compiling a function also avoids the overhead of the Python interpreter, which gives a speedup irrespective of the accelerator we use.

Implementing The Train and Validation Steps

The train and validation steps are stateless functions that perform a single iteration of training and validation each.

The Train Step

  • We write the train_step as a function of:
    • the train state, which is a combination of the model state and the optimizer state
    • the input rays
    • the target images
  • We use Mean-squared Error as the loss function.
  • We use the composable transform jax.value_and_grad to compute both the loss and the gradients.
  • We use Peak Signal-to-Noise Ratio as our metric, which is given by 20 log_10(MAX_f / √MSE). Since our pixel values lie in [0, 1], MAX_f = 1, so the PSNR reduces to −10 log_10(MSE).
  • We apply the composable transform jax.pmap on train_step. This will compile the function with XLA similarly to jax.jit, then execute it in parallel on XLA devices, in our case, on multiple TPU cores.
Let's take a look at the train_step function...
def train_step(state, batch, rng):
    """Train Step"""
    # Unravel the inputs and targets from the batch
    inputs, targets = batch

    # Compute the loss in a stateless manner
    def loss_fn(params):
        # Create the model function from the train state
        model_fn = lambda x: state.apply_fn({"params": params}, x)
        # Unravel the input rays
        ray_origins, ray_directions = inputs
        # Get the RGB view by performing volume rendering
        rgb, *_ = perform_volume_rendering(
            model_fn, ray_origins, ray_directions, rng
        )
        # Compute mean-squared error
        return jnp.mean((rgb - targets) ** 2)

    # Transform the loss function to get the loss value and the gradients
    train_loss, gradients = jax.value_and_grad(loss_fn)(state.params)
    # Compute an all-reduce mean on the gradients over the pmapped axis
    gradients = lax.pmean(gradients, axis_name="batch")
    # Update the model params and the optimizer state
    new_state = state.apply_gradients(grads=gradients)
    # Mean of the train loss of the batch
    train_loss = jnp.mean(train_loss)
    # Compute PSNR
    train_psnr = -10.0 * jnp.log(train_loss) / jnp.log(10.0)
    return train_loss, train_psnr, new_state


# Apply the transform jax.pmap on the train_step to parallelize it on XLA devices
parallelized_train_step = jax.pmap(train_step, axis_name="batch")
Note that the transform pmap is different from vmap regarding how they perform the computation. While vmap vectorizes a function by adding a batch dimension to every primitive operation in the function, pmap replicates the function and executes each replica on its own XLA device in parallel.
💡
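As a small illustration of the difference (a toy function, independent of the NeRF code), vmap vectorizes over a batch axis on a single device, while pmap expects one slice of the leading axis per device:

def square(x):
    return x ** 2

vectorized = jax.vmap(square)(jnp.arange(8))               # runs on a single device, vectorized over axis 0
parallelized = jax.pmap(square)(jnp.arange(DEVICE_COUNT))  # one replica per TPU core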

The Validation Step

The validation_step is almost the same as the train_step function, with just two differences:
  • Since we use just a single image for validation, we will apply the jax.jit transform to the validation_step instead of jax.pmap in order to compile it for the accelerators.
  • There would be no need to compute the gradients and no change in the model and optimizer state in the validation_step.
@jax.jit
def validation_step(state):
    """Validation Step"""
    # Create the model function from the state
    model_fn = lambda x: state.apply_fn({"params": state.params}, x)
    # Unravel the validation rays, which are a global
    ray_origins, ray_directions = val_rays
    # Get the rendered views by performing volume rendering
    rgb, depth, *_ = perform_volume_rendering(
        model_fn, ray_origins, ray_directions
    )
    # Compute the mean-squared error with the validation image, which is a global
    loss = jnp.mean((rgb - val_image) ** 2)
    # Compute the PSNR
    psnr = -10.0 * jnp.log(loss) / jnp.log(10.0)
    return rgb, depth, psnr, loss
Note that function transforms are simply callables that return a transformed function, hence we can use them as decorators as well.
💡

The Training and Validation Loop

Now let us bring everything together and write the training loop...
# Create the train and validation rays as globals
train_rays = np.stack(list(map(
    lambda x: generate_rays(image_height, image_width, focal, x), train_poses
)))
val_rays = generate_rays(image_height, image_width, focal, val_pose)

# Number of accelerators
n_devices = jax.local_device_count()

# Random Number Generator
key, rng = jax.random.split(jax.random.PRNGKey(0))

# Initialize the Model
model, params = initialize_model(key, (image_height * image_width, 3))

# Define the Optimizer
optimizer = optax.adam(learning_rate=config.learning_rate)

# Create the Training State
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optimizer
)

# Transfer arrays in the state to the specified devices and form ShardedDeviceArrays
state = jax.device_put_replicated(state, jax.local_devices())

# Function for executing the train and validation loop
def train_and_evaluate(state, train_step_fn, validation_step_fn):
    train_loss_history, train_psnr_history = [], []
    val_loss_history, val_psnr_history = [], []
    for epoch in tqdm(range(config.train_epochs)):
        # Shard the random number generators
        rng_index, rng_epoch = jax.random.split(jax.random.fold_in(rng, epoch))
        sharded_rngs = common_utils.shard_prng_key(rng_epoch)
        # Create the train batch: one random image and ray pair per device
        train_index = jax.random.randint(
            rng_index, (n_devices,), minval=0, maxval=len(train_rays)
        )
        train_batch = train_rays[tuple(train_index), ...], train_images[tuple(train_index), ...]
        # Perform the training step
        train_loss, train_psnr, state = train_step_fn(state, train_batch, sharded_rngs)
        train_loss_history.append(np.asarray(np.mean(train_loss)))
        train_psnr_history.append(np.asarray(np.mean(train_psnr)))
        wandb.log({"Train Loss": np.asarray(np.mean(train_loss))}, step=epoch)
        wandb.log({"Train PSNR": np.asarray(np.mean(train_psnr))}, step=epoch)
        # Perform the validation step on an unreplicated copy of the state
        validation_state = flax.jax_utils.unreplicate(state)
        rgb, depth, val_psnr, val_loss = validation_step_fn(validation_state)
        val_loss_history.append(np.asarray(val_loss))
        val_psnr_history.append(np.asarray(val_psnr))
        wandb.log({"Validation Loss": np.asarray(val_loss)}, step=epoch)
        wandb.log({"Validation PSNR": np.asarray(val_psnr)}, step=epoch)
        # Plot the results every plot interval
        if epoch % config.plot_interval == 0:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
            ax1.imshow(rgb)
            ax1.set_title(f"Predicted RGB at Epoch {epoch}")
            ax1.axis("off")
            ax2.imshow(depth)
            ax2.set_title(f"Predicted Depth at Epoch {epoch}")
            ax2.axis("off")
            plt.show()
    inference_state = flax.jax_utils.unreplicate(state)
    history = {
        "train_loss": train_loss_history,
        "train_psnr": train_psnr_history,
        "val_loss": val_loss_history,
        "val_psnr": val_psnr_history
    }
    return state, inference_state, history
Note that we use jax.device_put_replicated to transfer the arrays in the state to multiple TPU devices and form ShardedDeviceArrays. Every array in the pytree is replicated, so each local device holds an identical copy of the data.
💡
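For instance (a toy pytree, not the actual train state), replication adds a leading device axis to every leaf:

tiny_tree = {"w": jnp.ones(3)}
replicated_tree = jax.device_put_replicated(tiny_tree, jax.local_devices())
print(replicated_tree["w"].shape)  # (DEVICE_COUNT, 3): one copy of the array per local device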
Now, let us perform the training...
%%wandb

state, inference_state, history = train_and_evaluate(
state, parallelized_train_step, validation_step
)
After running wandb.init(), if we use %%wandb at the beginning of a new cell in our notebook, we can track the runs live from the notebook. For more information on tracking Weights & Biases runs from Jupyter notebooks, you can refer to the official guide.
💡


Experiment Tracking




Rendering The Scene

Now that our model has learned a representation of the scene, we should be able to reconstruct the scene from any given pose.
# Translation matrix for movement in t
def get_translation_matrix(t):
    return np.asarray(
        [
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, t],
            [0, 0, 0, 1],
        ]
    )

# Rotation matrix for movement in phi
def get_rotation_matrix_phi(phi):
    return np.asarray(
        [
            [1, 0, 0, 0],
            [0, np.cos(phi), -np.sin(phi), 0],
            [0, np.sin(phi), np.cos(phi), 0],
            [0, 0, 0, 1],
        ]
    )

# Rotation matrix for movement in theta
def get_rotation_matrix_theta(theta):
    return np.asarray(
        [
            [np.cos(theta), 0, -np.sin(theta), 0],
            [0, 1, 0, 0],
            [np.sin(theta), 0, np.cos(theta), 0],
            [0, 0, 0, 1],
        ]
    )

# Create the camera-to-world coordinate transform matrix
def pose_spherical(theta, phi, radius):
    camera_to_world_transform = get_translation_matrix(radius)
    camera_to_world_transform = get_rotation_matrix_phi(phi / 180.0 * np.pi) @ camera_to_world_transform
    camera_to_world_transform = get_rotation_matrix_theta(theta / 180.0 * np.pi) @ camera_to_world_transform
    camera_to_world_transform = np.array([
        [-1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]
    ]) @ camera_to_world_transform
    return camera_to_world_transform

# Function to render the scene from the rays
@jax.jit
def get_renderings(rays):
    model_fn = lambda x: inference_state.apply_fn(
        {"params": inference_state.params}, x
    )
    ray_origins, ray_directions = rays
    rgb, depth, acc, disparity, opacities = perform_volume_rendering(
        model_fn, ray_origins, ray_directions
    )
    rgb = (255 * jnp.clip(rgb, 0, 1)).astype(jnp.uint8)
    return rgb, depth, acc, disparity, opacities

# Create a 360 degree video of the 3D scene
def get_frames():
    video_angle = jnp.linspace(0.0, 360.0, 120, endpoint=False)
    camera_to_world_transform = map(lambda th: pose_spherical(th, -30.0, 4.0), video_angle)
    rays = np.stack(list(map(
        lambda x: generate_rays(
            image_height, image_width, focal, x[:3, :4]
        ), camera_to_world_transform
    )))
    rgb_frames, depth_frames, acc_maps, disparity_maps, opacities = lax.map(get_renderings, rays)
    rgb_frames = np.asarray(rgb_frames)
    depth_frames = np.asarray(depth_frames)
    acc_maps = np.asarray(acc_maps * 255.)
    disparity_maps = np.asarray(disparity_maps * 255.)
    return rgb_frames, depth_frames, acc_maps, disparity_maps


rgb_frames, depth_frames, acc_maps, disparity_maps = get_frames()

imageio.mimwrite("rgb_video.mp4", tuple(rgb_frames), fps=30, quality=7)
imageio.mimwrite("depth_video.mp4", tuple(depth_frames), fps=30, quality=7)
imageio.mimwrite("acc_video.mp4", tuple(acc_maps), fps=30, quality=7)
imageio.mimwrite("disparity_video.mp4", tuple(disparity_maps), fps=30, quality=7)

wandb.log({"RGB Rendering": wandb.Video("rgb_video.mp4", fps=30, format="gif")})
wandb.log({"Depth Rendering": wandb.Video("depth_video.mp4", fps=30, format="gif")})
wandb.log({"Accuracy Rendering": wandb.Video("acc_video.mp4", fps=30, format="gif")})
wandb.log({"Disparity Map Rendering": wandb.Video("disparity_video.mp4", fps=30, format="gif")})
Let us now look at the results demonstrated in the following panel...

RGB and Depthmap Reconstructions


Comparison With and Without Positional Encoding

As we have discussed earlier, without positional encoding, it becomes difficult for the NeRFModel to learn an intrinsic representation of the scene. Let us compare the learned reconstruction with and without positional encoding.

Comparison with and without Positional Encoding




Acknowledgments and Further Resources




What We Covered

In this report, we covered the following topics:
  • A basic overview of programming with JAX and Flax on Google Cloud TPUs.
  • Using Neural Radiance Fields to represent 3D scenes as a function mapping spatial locations and viewing directions to the radiance at each point.
  • Using Fourier Features obtained from the data to let Neural Networks better learn the intrinsic representations.
  • Synthesizing 3D views from a pre-trained Neural Radiance Field using Ray Tracing and Volume Rendering.
  • Using Weights & Biases to track our experiments, compare results, ensure reproducibility, and track utilization of the TPU during our experiment.


