Implementing NeRF in JAX
This article uses JAX to create a minimal implementation of 3D volumetric rendering of scenes represented by Neural Radiance Fields, using W&B to track all metrics.
In this article, we attempt to create a minimal implementation of 3D volumetric rendering of scenes represented by Neural Radiance Fields (NeRF) using JAX. The ideas used in this implementation were proposed in the paper NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis, where the authors propose a method that achieves state-of-the-art results for synthesizing novel views of complex scenes by optimizing an underlying continuous volumetric scene function using a sparse set of input views.
We train a model to learn the NeRF representation of a simple 3D scene using JAX and Flax on Google Cloud TPUs.
First, let's look at what we'll be covering:
Table of Contents
- What is JAX?
- What is Flax?
- Jumping into the Code
- Setting up the TPU for JAX
- Setting Up Our Weights & Biases Run
- What is a NeRF (Neural Radiance Field)?
- Implementing the Neural Network in Flax
- Training The Model
- Rendering The Scene
- Acknowledgments and Further Resources
- What We Covered
What is JAX?
JAX is an accelerated computation framework that brings together Autograd and XLA for high-performance machine learning research. It provides a simple NumPy and SciPy-like interface for writing code, compiled by the XLA Compiler to run on CPU, GPU, and TPU.

JAX also supports just-in-time (JIT) compilation of Python functions into XLA-optimized kernels using a one-function API. Thanks to its functional programming paradigm, JAX lets us apply composable transformations (such as grad, jit, vmap, and pmap) to a function without modifying it.
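As a quick illustration (a minimal sketch with toy values, not part of the NeRF code that follows), here's how grad and jit compose on an ordinary Python function:

import jax
import jax.numpy as jnp

# An ordinary Python function written with the NumPy-like jax.numpy API
def squared_error(w, x, y):
    return jnp.mean((w * x - y) ** 2)

# jax.grad transforms the function into one that returns its gradient w.r.t. w,
# and jax.jit compiles the result into an XLA-optimized kernel
grad_fn = jax.jit(jax.grad(squared_error))

x = jnp.arange(4.0)
y = 2.0 * x
print(grad_fn(3.0, x, y))  # gradient of the loss at w = 3.0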
What is Flax?
Flax is a neural network library and ecosystem for JAX, designed for flexibility. It provides us with a neural network API flax.linen that makes it easy for us to create models in a pythonic manner.
Additionally, Flax provides utilities and patterns for replicated training, serialization and checkpointing, metrics, and prefetching on the device.

Jumping into the Code
The code that we'll demonstrate in this report works with Google Cloud TPUs. You can run it for free on both Google Colab and Kaggle by setting the accelerator to TPU in either notebook environment.
Dependencies
!pip3 install -q -U jax jaxlib flax optax imageio-ffmpeg wandb
Other than JAX, Jaxlib, and Flax, we also install the following dependencies:
- Optax: a gradient processing and optimization library for JAX. Optax is designed to facilitate research by providing building blocks that can be easily recombined in custom ways (a short usage sketch follows this list).
- Weights & Biases for tracking all metrics and losses of our experiments in real-time and creating rich, interactive dashboards.
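To give a feel for how Optax plugs into JAX, here is a minimal sketch with a toy parameter pytree (separate from the NeRF training code later in this report):

import optax
import jax.numpy as jnp

# Build an Adam optimizer and initialize its state for a toy parameter pytree
optimizer = optax.adam(learning_rate=1e-3)
params = {"w": jnp.zeros(3)}
opt_state = optimizer.init(params)

# Given gradients (here just ones, for illustration), compute and apply updates
grads = {"w": jnp.ones(3)}
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)

In the actual training loop below, this init/update/apply cycle is handled for us by Flax's TrainState, which wraps the Optax optimizer.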
Now, let's import all the necessary libraries:
import os
import time
import wandb
import imageio
import requests
from typing import Any
import ipywidgets as widgets
from functools import partial
from tqdm.notebook import tqdm
from kaggle_secrets import UserSecretsClient

import jax
import flax
import optax
from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils

import numpy as np
import jax.numpy as jnp

from base64 import b64encode
from IPython.display import HTML
import plotly.express as px
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
Setting up the TPU for JAX
Tensor Processing Units (or TPUs) are hardware accelerators specializing in deep learning tasks. They were created by Google and have been behind many cutting-edge results in machine learning research. For a detailed guide to using TPUs on Kaggle Notebooks, you can refer to https://www.kaggle.com/docs/tpu.
Let's set them up:
# Detect if Kaggle Notebook has access to TPUs or not
if 'TPU_NAME' in os.environ:
    import requests
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print("TPU DETECTED!")
    print('Registered TPU:', config.FLAGS.jax_backend_target)
# Detect if Google Colab Notebook has access to TPUs or not
elif "COLAB_TPU_ADDR" in os.environ:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
else:
    print('No TPU detected.')

DEVICE_COUNT = len(jax.local_devices())
TPU = DEVICE_COUNT == 8

if TPU:
    print("8 cores of TPU ( Local devices in Jax ):")
    print('\n'.join(map(str, jax.local_devices())))
Setting Up Our Weights & Biases Run
We can now call wandb.init to initialize a new job. This creates a new run in Weights & Biases and launches a background process to sync data. We sync all the configs of our experiments with the W&B run, which makes it far easier for us to reproduce the results of the experiment later.
wandb.login()
wandb.init(project="nerf-jax", job_type="train")

# Sync all experiment configs with Weights & Biases
config = wandb.config
config.near_bound = 2.                   # Near bound of the sample space for 3D points
config.far_bound = 6.                    # Far bound of the sample space for 3D points
config.batch_size = int(1e4)             # Batch size
config.num_sample_points = 256           # Number of points to be sampled across the volume
config.epsilon = 1e10                    # Hyperparameter for volume rendering
config.apply_positional_encoding = True  # Whether to apply positional encoding to the input
config.positional_encoding_dims = 6      # Number of positional encoding frequencies
config.num_dense_layers = 8              # Number of dense layers in the MLP
config.dense_layer_width = 256           # Dimensionality of the dense layers' output space
config.learning_rate = 5e-4              # Learning rate
config.train_epochs = 1000               # Number of training epochs
config.plot_interval = 100               # Epoch interval for plotting results during training
We'll discuss what each of these configs represents in the later sections of this report.
What is a NeRF (Neural Radiance Field)?
A NeRF (Neural Radiance Field) is a neural network that continuously maps a point in 3D space and a viewing direction to the amount of light emitted by that point in that direction. It allows us to synthesize novel (new) views of complex scenes.
To clarify, let's break it down into its component parts:
- The word neural obviously means that there's a Neural Network involved
- Radiance refers to the radiance of the scene that the Neural Network outputs. It is basically describing how much light is being emitted by a point in space in each direction, and
- The word Field means that the Neural Network models a continuous and non-discretized representation of the scene it learns.

Figure: Overview of the Model
- The Neural Network in this case, is a simple Multi-layered Perceptron or MLP with ReLU activation.
- The MLP consists of 9 fully-connected layers of width 256.
- The input to the MLP consists of 2 components:
- $\mathbf{x} = (x, y, z)$, which denotes the spatial position of a given point in 3D space.
- $\mathbf{d} = (\theta, \phi)$, which denotes a given viewing direction from the point.
- Together, $(x, y, z, \theta, \phi)$ forms a single continuous 5D coordinate which is fed to the MLP.
- The output of the MLP consists of 2 components as well:
- $\mathbf{c} = (r, g, b)$, which denotes the view composed from the point along the direction $\mathbf{d}$ in RGB colorspace.
- $\sigma$, which denotes the density or transparency of the point.
- The value of $\sigma$ lies in the range $[0, \infty)$ (our implementation applies a ReLU to the raw network output).
- A value of $0$ means there is nothing at the point (the point is transparent), while a higher value means that the point is more opaque.
Implementing the Neural Network in Flax
We'll use the flax.linen API to create the neural network. So what is linen exactly? Simple: linen is a neural network API developed as part of the Flax package. It builds on a functional core, enabling direct usage of JAX transformations such as vmap, remat, or scan inside our modules.
We'll create our NeRFModel as a subclass of flax.linen.Module, similar to how we would use the torch.nn.Module API in PyTorch or the tf.keras.Model API in TensorFlow.
class NeRFModel(nn.Module):
    # dtype of the computation
    dtype: Any = jnp.float32
    # numerical precision of the computation
    precision: Any = lax.Precision.DEFAULT
    # Apply positional encoding or not
    apply_positional_encoding: bool = config.apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        # Apply positional encoding to the input points
        x = positional_encoding(input_points) if self.apply_positional_encoding else input_points
        for i in range(config.num_dense_layers):
            # Fully-connected layer
            x = nn.Dense(
                config.dense_layer_width,
                dtype=self.dtype,
                precision=self.precision
            )(x)
            # ReLU activation function
            x = nn.relu(x)
            # Skip connection
            x = jnp.concatenate([x, input_points], axis=-1) if i == 4 else x
        # Output consists of 4 values: (r, g, b, sigma)
        x = nn.Dense(4, dtype=self.dtype, precision=self.precision)(x)
        return x
Positional Encodings
Let's step back from our NeRFModel and consider a simpler problem to act as a proxy for NeRF:
Make a neural network memorize a single image.
The model takes a set of pixel coordinates as input and predicts the values of those pixels. We train this model by overfitting it on a single image, with the objective of memorizing that single image accurately. Note that we consider the model to be an MLP with a ReLU activation function, similar to our NeRFModel.
The surprising thing is that such a model contains 8 to 10 times more parameters than pixels in the image, but it still does a rather poor job, even when given an extremely generous amount of training time.

It's almost as if we can see the ReLU activation in the attempted reconstruction 😱
What's The Solution?
Thankfully, the NeRF authors found a solution: they observed that if we take sinusoids of the input coordinates at increasing frequencies, expanding the coordinates into a high-dimensional feature vector, and feed this concatenated feature vector into the MLP instead of the raw coordinates, the model is able to memorize the image!
Although the idea of Fourier Feature Mapping was presented in the original NeRF paper we're focusing on, the authors didn't provide any explanation as to why it worked. The explanation came in a later work titled Fourier Features Let Networks Learn High-Frequency Functions in Low-Dimensional Domains.
In fact, the explanation is right there in the title. Basically, passing input points through a simple Fourier Feature Mapping enables an MLP to learn high-frequency functions (such as an RGB image) in low-dimensional problem domains (such as a 2D coordinate of pixels).
This pre-processing step of encoding the input is called Fourier Feature Mapping or Positional Encoding. This encoding can be represented as

$$\gamma(p) = \left(\sin(2^0 \pi p), \cos(2^0 \pi p), \ldots, \sin(2^{L-1} \pi p), \cos(2^{L-1} \pi p)\right)$$

or, as implemented in this report (without the factor of $\pi$ and with the raw input concatenated),

$$\gamma(p) = \left(p, \sin(2^0 p), \cos(2^0 p), \ldots, \sin(2^{L-1} p), \cos(2^{L-1} p)\right)$$

where $L$ is the number of encoding frequencies, given by config.positional_encoding_dims.

As we can see, the MLP with Fourier features is able to memorize the image quickly 🚀
Let's check what the Fourier Feature Mapping looks like in code:
def positional_encoding(inputs):
    batch_size, _ = inputs.shape
    # Applying vmap transform to vectorize the multiplication operation
    inputs_freq = jax.vmap(
        lambda x: inputs * 2.0 ** x
    )(jnp.arange(config.positional_encoding_dims))
    periodic_fns = jnp.stack([jnp.sin(inputs_freq), jnp.cos(inputs_freq)])
    periodic_fns = periodic_fns.swapaxes(0, 2).reshape([batch_size, -1])
    periodic_fns = jnp.concatenate([inputs, periodic_fns], axis=-1)
    return periodic_fns
Note that we apply the composable transform jax.vmap on the multiplication operation, which "vectorizes" the operation. What this means is that it allows us to compute the output of a function in parallel over some axis of the input.
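Here is a tiny, self-contained illustration of that behavior (a toy example, not part of the NeRF code):

import jax
import jax.numpy as jnp

# A function written for a single scalar frequency exponent
scale_by_power_of_two = lambda x: jnp.ones(3) * 2.0 ** x

# jax.vmap maps it over an axis of the input, producing one output row per frequency
batched = jax.vmap(scale_by_power_of_two)(jnp.arange(4))
print(batched.shape)  # (4, 3)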
View Synthesis
Before we get into discussing view synthesis, let's grab our dataset and take a look at it.
We'll be using the Tiny-NeRF dataset, which is a subset of the original Blender dataset that was used by the authors. We'll use the first 100 images and their respective poses as training data. For validation, we use a single image-pose pair.
The small validation subset is not going to be an issue in our case since we essentially want our network to "memorize" the 3D scene given by the data in a way that enables smooth interpolation between frames.
!wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz

data = np.load("tiny_nerf_data.npz")
images = data["images"]
poses = data["poses"]
focal = float(data["focal"])
_, image_height, image_width, _ = images.shape

# We would use the first 100 images, poses as training data
train_images, train_poses = images[:100], poses[:100]
# We use a single image, pose pair for validation
val_image, val_pose = images[101], poses[101]

# Visualize the training data
fig = plt.figure(figsize=(16, 16))
grid = ImageGrid(fig, 111, nrows_ncols=(4, 4), axes_pad=0.1)
random_images = images[np.random.choice(np.arange(images.shape[0]), 16)]
for ax, image in zip(grid, random_images):
    ax.imshow(image)
plt.title("Sample Images from Tiny-NeRF Data")
plt.show()

The dataset basically consists of a set of images of this Lego model taken from multiple poses. The task at hand is to learn a representation of this object such that we can synthesize views of it from any given pose, even ones that are not present in the dataset.

An example of 3D images collected from various camera poses along a spherical surface.
View Synthesis using Volume Rendering
Now that we have a dense sampling of views represented by an infinitely high-resolution field of radiance that's defined by the MLP NeRFModel, photorealistic novel views can be reconstructed by simple light field sample interpolation techniques.
For synthesizing novel views, the authors of NeRF propose using volumetric representations to address the task of high-quality, photorealistic view synthesis from a set of input RGB images. Volumetric approaches are able to realistically represent complex shapes and materials, are well-suited for gradient-based optimization, and tend to produce less visually distracting artifacts than mesh-based methods.
We would render the color of any ray passing through the scene using principles from classical volume rendering. The expected color $C(\mathbf{r})$ of a camera ray given by $\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}$, with near bound $t_n$ and far bound $t_f$, is given by

$$C(\mathbf{r}) = \int_{t_n}^{t_f} T(t)\, \sigma(\mathbf{r}(t))\, \mathbf{c}(\mathbf{r}(t), \mathbf{d})\, dt \qquad \text{where} \qquad T(t) = \exp\left(-\int_{t_n}^{t} \sigma(\mathbf{r}(s))\, ds\right)$$

where...
- $\sigma(\mathbf{r}(t))$ is the volume density predicted by the MLP, which can be interpreted as the differential probability of a ray terminating at an infinitesimal particle at the location $\mathbf{r}(t)$. It is basically the opacity of the given point.
- $\mathbf{c}(\mathbf{r}(t), \mathbf{d})$ is the RGB color predicted by the MLP, which can be interpreted as the radiance of a ray terminating at the location $\mathbf{r}(t)$.
- $T(t)$ denotes the accumulated transmittance along the ray from the near bound $t_n$ to $t$, i.e., the probability that the ray travels from $t_n$ to $t$ without hitting any other particle.
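For reference, this continuous integral is estimated with the numerical quadrature rule from the paper, which is what the compute_weights and perform_volume_rendering functions below implement:

$$\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i \left(1 - \exp(-\sigma_i \delta_i)\right) \mathbf{c}_i, \qquad T_i = \exp\left(-\sum_{j=1}^{i-1} \sigma_j \delta_j\right)$$

where $\delta_i = t_{i+1} - t_i$ is the distance between adjacent samples along the ray and $N$ is the number of sample points.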

In order to perform volume rendering, we have to perform the following steps:
Generate Rays: Given a pose, we have to generate a grid of rays corresponding to the dimensions of the image that we desire to render. This is done by the following function:
def generate_rays(height, width, focal, pose):
    # Create a 2D rectangular grid for the rays corresponding to image dimensions
    i, j = np.meshgrid(np.arange(width), np.arange(height), indexing="xy")
    transformed_i = (i - width * 0.5) / focal    # Normalize the x-axis coordinates
    transformed_j = -(j - height * 0.5) / focal  # Normalize the y-axis coordinates
    k = -np.ones_like(i)                         # z-axis coordinates
    # Create the unit vectors corresponding to ray directions
    directions = np.stack([transformed_i, transformed_j, k], axis=-1)
    # Compute Origins and Directions for each ray
    camera_directions = directions[..., None, :] * pose[:3, :3]
    ray_directions = np.einsum("ijl,kl", directions, pose[:3, :3])
    ray_origins = np.broadcast_to(pose[:3, -1], ray_directions.shape)
    return np.stack([ray_origins, ray_directions])
Compute 3D query points between the near bound $t_n$ and far bound $t_f$. We would use a stratified sampling approach where we partition $[t_n, t_f]$ into $N$ evenly-spaced bins and then draw one sample uniformly at random from within each bin. This action is performed by the following function:
def compute_3d_points(ray_origins, ray_directions, random_number_generator=None):
    """Compute 3d query points for volumetric rendering"""
    # Sample space to parametrically compute the ray points
    t_vals = np.linspace(config.near_bound, config.far_bound, config.num_sample_points)
    if random_number_generator is not None:
        # inject a uniform noise into the sample space to make it continuous
        t_shape = ray_origins.shape[:-1] + (config.num_sample_points,)
        noise = jax.random.uniform(random_number_generator, t_shape) * (
            config.far_bound - config.near_bound
        ) / config.num_sample_points
        t_vals = t_vals + noise
    # Compute the ray traversal points using r(t) = o + t * d
    ray_origins = ray_origins[..., None, :]
    ray_directions = ray_directions[..., None, :]
    t_vals_flat = t_vals[..., :, None]
    points = ray_origins + ray_directions * t_vals_flat
    return points, t_vals
Note that $N$ is given by config.num_sample_points in our code.
💡
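For completeness, the stratified sampling rule from the paper draws each sample as

$$t_i \sim \mathcal{U}\left[\,t_n + \frac{i-1}{N}(t_f - t_n),\; t_n + \frac{i}{N}(t_f - t_n)\,\right]$$

which compute_3d_points approximates by adding uniform noise of width $(t_f - t_n)/N$ to evenly spaced samples.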
Compute the Radiance Field: We get the colors and opacities from the NeRFModel. This action is performed by the following function:
def compute_radiance_field(model, points):
    """Compute Radiance Field"""
    # Perform forward propagation
    model_output = lax.map(model, jnp.reshape(points, [-1, config.batch_size, 3]))
    radiance_field = jnp.reshape(model_output, points.shape[:-1] + (4,))
    # Slice the model output
    opacities = nn.relu(radiance_field[..., 3])
    colors = nn.sigmoid(radiance_field[..., :3])
    return opacities, colors
Compute Adjacent Distances: Get the distances between adjacent intervals along the sample space. This action is performed by the following function:
def compute_adjacent_distances(t_vals, ray_directions):
    """Get distances between adjacent intervals along sample space"""
    distances = t_vals[..., 1:] - t_vals[..., :-1]
    distances = jnp.concatenate([
        distances,
        np.broadcast_to([config.epsilon], distances[..., :1].shape)
    ], axis=-1)
    # Multiply each distance by the norm of its corresponding direction ray
    # to convert to real world distance (accounts for non-unit directions)
    distances = distances * jnp.linalg.norm(ray_directions[..., None, :], axis=-1)
    return distances
Compute Transmittance: The accumulated transmittance along the ray is given by $T_i = \exp\left(-\sum_{j=1}^{i-1} \sigma_j \delta_j\right)$, where $\delta_i = t_{i+1} - t_i$ is the distance between adjacent intervals along the sample space. This action is performed by the following function:
def compute_weights(opacities, distances):
    """Compute weight for the RGB of each sample along each ray"""
    # Compute density from the opacity
    density = jnp.exp(-opacities * distances)
    alpha = 1.0 - density
    clipped_difference = jnp.clip(1.0 - alpha, 1e-10, 1.0)
    # A cumulative product is basically used to express the idea
    # of the ray not having reflected up to this sample yet
    transmittance = jnp.cumprod(
        jnp.concatenate([
            jnp.ones_like(clipped_difference[..., :1]),
            clipped_difference[..., :-1]
        ], -1),
        axis=-1
    )
    return alpha * transmittance
Now, all we need to do is bring all the steps together to perform volume rendering and synthesize the desired view.
def perform_volume_rendering(model, ray_origins, ray_directions, random_number_generator=None):
    # Compute 3d query points
    points, t_vals = compute_3d_points(ray_origins, ray_directions, random_number_generator)
    # Get color and opacities from the model
    opacities, colors = compute_radiance_field(model, points)
    # Get distances between adjacent intervals along sample space
    distances = compute_adjacent_distances(t_vals, ray_directions)
    # Compute weight for the RGB of each sample along each ray
    weights = compute_weights(opacities, distances)
    # Compute weighted RGB color of each sample along each ray
    rgb_map = jnp.sum(weights[..., None] * colors, axis=-2)
    # Compute the estimated depth map
    depth_map = jnp.sum(weights * t_vals, axis=-1)
    # Sum of weights along each ray; the value is in [0, 1] up to numerical error
    acc_map = jnp.sum(weights, axis=-1)
    # Disparity map is basically the inverse of depth
    disparity_map = 1. / jnp.maximum(1e-10, depth_map / jnp.sum(weights, axis=-1))
    return rgb_map, depth_map, acc_map, disparity_map, opacities
The Tiny NeRF dataset does not include view directions, hence the inputs used in this report are 3D instead of 5D as presented in the paper.
💡
Training The Model
Initializing The Model
First, the code.
def initialize_model(key, input_pts_shape):
    # Create an instance of the model
    model = NeRFModel()
    # Initialize the model parameters
    initial_params = jax.jit(model.init)(
        {"params": key},
        jnp.ones(input_pts_shape),
    )
    return model, initial_params["params"]
Note that we apply the jax.jit transform to the model initialization operation so that it gets compiled just in time by XLA. This makes it more efficient to run on an accelerator, which is a TPU in our case. Compiling a function also avoids the overhead of the Python interpreter, which gives us a speedup irrespective of the accelerator we use.
Implementing The Train and Validation Steps
The train and validation steps are stateless functions that perform a single iteration of training and validation each.
The Train Step
- We write the train_step as a function of:
- the train state, which is a combination of the model state and the optimizer state
- the input rays
- the target images
- We use Mean-squared Error as the loss function.
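The train step below also logs the peak signal-to-noise ratio (PSNR), which for images normalized to $[0, 1]$ follows directly from the MSE:

$$\text{PSNR} = -10 \cdot \log_{10}(\text{MSE})$$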
Let's take a look at the train_step function...
def train_step(state, batch, rng):
    """Train Step"""
    # Unravel the inputs and targets from the batch
    inputs, targets = batch

    # Compute the loss in a stateless manner
    def loss_fn(params):
        # Create the model function from the train state
        model_fn = lambda x: state.apply_fn({"params": params}, x)
        # Unravel the input rays
        ray_origins, ray_directions = inputs
        # Get the RGB view by performing volume rendering
        rgb, *_ = perform_volume_rendering(model_fn, ray_origins, ray_directions, rng)
        # Compute mean-squared error
        return jnp.mean((rgb - targets) ** 2)

    # Transform the loss function to get the loss value and the gradients
    train_loss, gradients = jax.value_and_grad(loss_fn)(state.params)
    # Compute all-reduce mean on gradients over the pmapped axis
    gradients = lax.pmean(gradients, axis_name="batch")
    # Update the model params and the optimizer state
    new_state = state.apply_gradients(grads=gradients)
    # Mean of train loss of the batch
    train_loss = jnp.mean(train_loss)
    # Compute PSNR
    train_psnr = -10.0 * jnp.log(train_loss) / jnp.log(10.0)
    return train_loss, train_psnr, new_state


# Apply the transform jax.pmap on the train_step to parallelize it on XLA devices
parallelized_train_step = jax.pmap(train_step, axis_name="batch")
Note that the transform pmap is different from vmap regarding how they perform the computation. While vmap vectorizes a function by adding a batch dimension to every primitive operation in the function, pmap replicates the function and executes each replica on its own XLA device in parallel.
💡
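As a minimal sketch of that difference (a toy example, separate from the NeRF code):

import jax
import jax.numpy as jnp

double = lambda x: 2.0 * x

# vmap: runs on a single device, vectorizing over the leading axis
print(jax.vmap(double)(jnp.arange(8.0)))

# pmap: one replica of the function per device; the leading axis
# of the input must match the number of local devices
n_devices = jax.local_device_count()
print(jax.pmap(double)(jnp.arange(float(n_devices))))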
The Validation Step
The validation_step is almost the same as the train_step function, with just two differences:
- There is no need to compute gradients, and the model and optimizer state are not updated in the validation_step.
- It operates on the fixed validation rays and image (globals), so it is simply wrapped in jax.jit on a single device rather than pmapped across devices.
@jax.jit
def validation_step(state):
    """Validation Step"""
    # Create the model function from the state
    model_fn = lambda x: state.apply_fn({"params": state.params}, x)
    # Unravel the validation rays, which is a global
    ray_origins, ray_directions = val_rays
    # Get the rendered views by performing volume rendering
    rgb, depth, *_ = perform_volume_rendering(model_fn, ray_origins, ray_directions)
    # Compute Mean-squared error with the validation image, which is a global
    loss = jnp.mean((rgb - val_image) ** 2)
    # Compute the PSNR
    psnr = -10.0 * jnp.log(loss) / jnp.log(10.0)
    return rgb, depth, psnr, loss
Note that function transforms are simply callables that return a transformed function, hence we can use them as decorators as well.
💡
The Training and Validation Loop
Now let us bring everything together and write the training loop...
# Create the train and validation rays as globals
train_rays = np.stack(list(map(
    lambda x: generate_rays(image_height, image_width, focal, x), train_poses
)))
val_rays = generate_rays(image_height, image_width, focal, val_pose)

# Number of accelerators
n_devices = jax.local_device_count()

# Random Number Generator
key, rng = jax.random.split(jax.random.PRNGKey(0))

# Initialize the Model
model, params = initialize_model(key, (image_height * image_width, 3))

# Define the Optimizer
optimizer = optax.adam(learning_rate=config.learning_rate)

# Create the Training State
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

# Transfer arrays in the state to the specified devices and form ShardedDeviceArrays
state = jax.device_put_replicated(state, jax.local_devices())


# Function for executing the train and validation loop
def train_and_evaluate(state, train_step_fn, validation_step_fn):
    train_loss_history, train_psnr_history = [], []
    val_loss_history, val_psnr_history = [], []

    for epoch in tqdm(range(config.train_epochs)):
        # Shard Random Number Generators
        rng_index, rng_epoch = jax.random.split(jax.random.fold_in(rng, epoch))
        sharded_rngs = common_utils.shard_prng_key(rng_epoch)

        # Create the Train Batch
        train_index = jax.random.randint(
            rng_index, (n_devices,), minval=0, maxval=len(train_rays)
        )
        train_batch = train_rays[tuple(train_index), ...], train_images[tuple(train_index), ...]

        # Perform the Training Step
        train_loss, train_psnr, state = train_step_fn(state, train_batch, sharded_rngs)
        train_loss_history.append(np.asarray(np.mean(train_loss)))
        train_psnr_history.append(np.asarray(np.mean(train_psnr)))
        wandb.log({"Train Loss": np.asarray(np.mean(train_loss))}, step=epoch)
        wandb.log({"Train PSNR": np.asarray(np.mean(train_psnr))}, step=epoch)

        # Perform the Validation Step
        validation_state = flax.jax_utils.unreplicate(state)
        rgb, depth, val_psnr, val_loss = validation_step_fn(validation_state)
        val_loss_history.append(np.asarray(val_loss))
        val_psnr_history.append(np.asarray(val_psnr))
        wandb.log({"Validation Loss": np.asarray(val_loss)}, step=epoch)
        wandb.log({"Validation PSNR": np.asarray(val_psnr)}, step=epoch)

        # Plot the result every plot interval
        if epoch % config.plot_interval == 0:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
            ax1.imshow(rgb)
            ax1.set_title(f"Predicted RGB at Epoch {epoch}")
            ax1.axis("off")
            ax2.imshow(depth)
            ax2.set_title(f"Predicted Depth at Epoch {epoch}")
            ax2.axis("off")
            plt.show()

    inference_state = flax.jax_utils.unreplicate(state)
    history = {
        "train_loss": train_loss_history,
        "train_psnr": train_psnr_history,
        "val_loss": val_loss_history,
        "val_psnr": val_psnr_history
    }
    return state, inference_state, history
Note that we use jax.device_put_replicated to transfer the arrays in the state to multiple TPU devices and form ShardedDeviceArrays: every array in the pytree is replicated, so all local devices hold the same data.
💡
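A minimal sketch of what that replication looks like (a toy pytree, not the actual train state):

import jax
import jax.numpy as jnp

toy_state = {"weights": jnp.ones((2, 2))}
replicated = jax.device_put_replicated(toy_state, jax.local_devices())
# Each leaf gains a leading device axis of size jax.local_device_count()
print(replicated["weights"].shape)  # (n_devices, 2, 2)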
Now, let us perform the training...
%%wandb
state, inference_state, history = train_and_evaluate(state, parallelized_train_step, validation_step)
After running wandb.init(), if we use %%wandb at the beginning of a new cell in our notebook, we can track the runs live from our notebook. For more information on tracking Weights & Biases runs from a Jupyter notebook, you can refer to the official guide.
💡
Experiment Tracking
Rendering The Scene
Now that our model has learned a representation of the scene, we should be able to reconstruct the scene from any given pose.
# Translation matrix for movement in t
def get_translation_matrix(t):
    return np.asarray([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ])


# Rotation matrix for movement in phi
def get_rotation_matrix_phi(phi):
    return np.asarray([
        [1, 0, 0, 0],
        [0, np.cos(phi), -np.sin(phi), 0],
        [0, np.sin(phi), np.cos(phi), 0],
        [0, 0, 0, 1],
    ])


# Rotation matrix for movement in theta
def get_rotation_matrix_theta(theta):
    return np.asarray([
        [np.cos(theta), 0, -np.sin(theta), 0],
        [0, 1, 0, 0],
        [np.sin(theta), 0, np.cos(theta), 0],
        [0, 0, 0, 1],
    ])


# Create the camera-to-world coordinate transform matrix
def pose_spherical(theta, phi, radius):
    camera_to_world_transform = get_translation_matrix(radius)
    camera_to_world_transform = get_rotation_matrix_phi(phi / 180.0 * np.pi) @ camera_to_world_transform
    camera_to_world_transform = get_rotation_matrix_theta(theta / 180.0 * np.pi) @ camera_to_world_transform
    camera_to_world_transform = np.array([
        [-1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]
    ]) @ camera_to_world_transform
    return camera_to_world_transform


# Function to render the scene from the rays
@jax.jit
def get_renderings(rays):
    model_fn = lambda x: inference_state.apply_fn({"params": inference_state.params}, x)
    ray_origins, ray_directions = rays
    rgb, depth, acc, disparity, opacities = perform_volume_rendering(model_fn, ray_origins, ray_directions)
    rgb = (255 * jnp.clip(rgb, 0, 1)).astype(jnp.uint8)
    return rgb, depth, acc, disparity, opacities


# Create a 360-degree video of the 3D scene
def get_frames():
    video_angle = jnp.linspace(0.0, 360.0, 120, endpoint=False)
    camera_to_world_transform = map(lambda th: pose_spherical(th, -30.0, 4.0), video_angle)
    rays = np.stack(list(map(
        lambda x: generate_rays(image_height, image_width, focal, x[:3, :4]),
        camera_to_world_transform
    )))
    rgb_frames, depth_frames, acc_maps, disparity_maps, opacities = lax.map(get_renderings, rays)
    rgb_frames = np.asarray(rgb_frames)
    depth_frames = np.asarray(depth_frames)
    acc_maps = np.asarray(acc_maps * 255.)
    disparity_maps = np.asarray(disparity_maps * 255.)
    return rgb_frames, depth_frames, acc_maps, disparity_maps


rgb_frames, depth_frames, acc_maps, disparity_maps = get_frames()

imageio.mimwrite("rgb_video.mp4", tuple(rgb_frames), fps=30, quality=7)
imageio.mimwrite("depth_video.mp4", tuple(depth_frames), fps=30, quality=7)
imageio.mimwrite("acc_video.mp4", tuple(acc_maps), fps=30, quality=7)
imageio.mimwrite("disparity_video.mp4", tuple(disparity_maps), fps=30, quality=7)

wandb.log({"RGB Rendering": wandb.Video("rgb_video.mp4", fps=30, format="gif")})
wandb.log({"Depth Rendering": wandb.Video("depth_video.mp4", fps=30, format="gif")})
wandb.log({"Accuracy Rendering": wandb.Video("acc_video.mp4", fps=30, format="gif")})
wandb.log({"Disparity Map Rendering": wandb.Video("disparity_video.mp4", fps=30, format="gif")})
Let us now look at the results demonstrated in the following panel...
RGB and Depthmap Reconstructions
Comparison With and Without Positional Encoding
As we have discussed earlier, without positional encoding, it becomes difficult for the NeRFModel to learn an intrinsic representation of the scene. Let us compare the learned reconstruction with and without positional encoding.
Comparison with and without Positional Encoding
Acknowledgments and Further Resources
- The ideas used in this report are based primarily on the papers NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis and Fourier Features Let Networks Learn High-Frequency Functions in Low-Dimensional Domains.
- This work is primarily based on tiny_nerf.ipynb which was used by the authors to demonstrate the concept of NeRF on the tiny_nerf dataset.
- Our work also draws inspiration from 3D volumetric rendering with NeRF on the Keras Docs by Aritra Roy Gosthipaty and Ritwik Raha.
- We also referred to the official Jax implementation of NeRF google-research/jaxnerf and the unofficial implementation by myagues/flax_nerf.
- Several theoretical explanations used in this notebook were derived from Jon Barron's talk at MIT on NeRF.
- Several of the figures used in this notebook were inspired by the figures from the respective papers and the official project page of NeRF.
- The snippets for Positional Encoding were derived from the NeurIPS 2020 Spotlight video for the paper.
- Ray Tracing in One Weekend Book Series by Peter Shirley provided us with an excellent primer on Ray Tracing.
What We Covered
In this report, we covered the following topics:
- Using Neural Radiance Fields to represent 3D scenes as a function mapping spatial locations and viewing directions to the radiance at each point.
- Using Fourier Features obtained from the data to let Neural Networks better learn the intrinsic representations.
- Synthesizing 3D views from a pre-trained Neural Radiance Field using Ray Tracing and Volume Rendering.
- Using Weights & Biases to track our experiments, compare results, ensure reproducibility, and track utilization of the TPU during our experiment.