Training Neural Radiance Field (NeRF) Models with Keras/TensorFlow and DeepVision

Neural Radiance Fields, colloquially known as NeRFs, took the world by storm in 2020, when they were released alongside the paper "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis". They are still the cornerstone of high-quality novel view synthesis from sparse images and camera positions.

Since then, they've found numerous applications, most prominently in geospatial volumetric modeling. Companies like Google rely on NeRFs to create 3D structures of buildings and heritage sites from satellite imagery taken at various angles, and other companies specialize in the 3D reconstruction and digitization of well-known cultural sites.

In this guide, we'll be training a Neural Radiance Field (NeRF) model on the original Tiny NeRF dataset, using TensorFlow/Keras and DeepVision, to perform novel view synthesis/3D reconstruction.

In a single hour, on a consumer machine, you'll be able to render novel views of the scene in the TinyNeRF dataset.

Novel View Synthesis and Neural Radiance Fields

This section provides a simplified introduction to how Neural Radiance Fields work. If you're new to the field, it may take some time to truly digest them intuitively.

Note: The original paper, as well as the educational videos and graphics associated with it, are great learning materials. If you're interested in understanding the underlying concept of radiance fields that NeRFs rely on to represent a scene, the Wikipedia entry for "light fields" provides a great introduction. At a high level, light fields can be summarized as:

"The light field is a vector function that describes the amount of light flowing in every direction through every point in space".

NeRFs are used for novel view synthesis: creating new views of an object or scene, given some existing views. In effect, you can think of novel view synthesis as 2D->3D conversion. Many approaches to this problem exist, some more successful than others.

Novel view synthesis has historically been a challenging problem, yet the solution proposed by NeRFs is exceedingly simple and yields state-of-the-art results, generating very high-quality images from novel angles.

This, naturally, positioned them as a foundational approach to solving novel view synthesis, with many subsequent papers exploring, adjusting and improving on the ideas present therein.

Advice: The website released alongside the paper contains an amazing showcase of the method and its results. The authors have also officially released an educational video that builds a good intuition for how these networks work.

The pipeline from data to results can be summarized as follows. The neural network learns from sparse images, with synthetically generated rays that are cast through the images and sampled at regular intervals. The images are positioned in space using metadata such as the camera position at the time each image was taken. Because of this, you can't just input arbitrary images: camera positions are required to accurately place the images in space, so that the rays produce a coherent set of points. The sampled points then form a 3D point set that represents the volumetric scene.

The neural network approximates a volumetric scene function: the RGB values and density (σ) of a scene. In effect, we train the network to memorize the color and density at each input point, so that it can reconstruct the images from novel viewpoints. That being said, NeRFs aren't trained on a set of images in order to generalize to new ones. A NeRF is trained to encode a single scene and is then only used for that one scene, as the weights of the network itself represent the scene.
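For reference, the original paper writes the scene function as F_Θ, mapping a 3D position x and a viewing direction d to an emitted color c and a density σ. The color of a camera ray r is then estimated from N samples at depths t_i along it with the quadrature below (notation follows the paper; δ_i is the distance between adjacent samples, and training minimizes the squared error between rendered and true pixel colors):

```latex
F_\Theta : (\mathbf{x}, \mathbf{d}) \mapsto (\mathbf{c}, \sigma)

\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i \left(1 - \exp(-\sigma_i \delta_i)\right) \mathbf{c}_i,
\qquad T_i = \exp\Big(-\sum_{j=1}^{i-1} \sigma_j \delta_j\Big),
\qquad \delta_i = t_{i+1} - t_i

\mathcal{L} = \sum_{\mathbf{r}} \big\| \hat{C}(\mathbf{r}) - C(\mathbf{r}) \big\|_2^2
```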

This is the main "drawback" of NeRFs: you have to train a network for each scene you want to encode, and the training process is somewhat slow and requires lots of memory for large inputs. Reducing training time is an active area of research, with novel techniques such as "Direct Voxel Grid Optimization" significantly improving training time without trading off image quality in the process.

Neural Radiance Fields in DeepVision and TensorFlow

NeRF implementations can be a bit daunting for those new to volumetric rendering, and the code repositories typically include many helper methods for dealing with volumetric data, which may look unintuitive to some. DeepVision is a novel computer vision library that aims to unify computer vision under a common API, with interchangeable backends (TensorFlow and PyTorch), automatic weight conversions between models, and models with identical implementations across backend frameworks.

To lower the barrier to entry, DeepVision offers a simple yet true-to-the-original implementation of Neural Radiance Field models, with multiple setups to accommodate machines with varying hardware capabilities:

  • NeRFTiny
  • NeRFSmall
  • NeRFMedium
  • NeRF
  • NeRFLarge

Two parameters are used to create these setups: width and depth. Since NeRFs are, in essence, just MLP models consisting of tf.keras.layers.Dense() layers (with a single concatenation between layers), the depth directly represents the number of Dense layers, while the width represents the number of units in each one.
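To make width and depth concrete, here's a minimal NumPy sketch of a NeRF-style MLP forward pass. It is not DeepVision's implementation: the helpers init_nerf_mlp() and nerf_mlp_forward() are hypothetical, ReLU hidden activations are assumed (as in the original paper), and the defaults (width 128, 7 Dense layers, skip concatenation after the fifth) were picked to mirror the NeRFMedium summary printed in this guide.

```python
import numpy as np

def init_nerf_mlp(input_features, width=128, depth=7, skip_after=5, out_features=4, seed=0):
    """Hypothetical helper: random (weight, bias) pairs for a NeRF-style MLP.

    `depth` counts every Dense layer including the output layer; the raw input
    is concatenated back in after the `skip_after`-th layer.
    """
    rng = np.random.default_rng(seed)
    layers = []
    in_dim = input_features
    for i in range(depth):
        out_dim = out_features if i == depth - 1 else width
        w = rng.normal(0.0, np.sqrt(2.0 / in_dim), size=(in_dim, out_dim))
        layers.append((w, np.zeros(out_dim)))
        in_dim = out_dim
        if i + 1 == skip_after:
            in_dim += input_features  # the skip connection widens the next layer's input
    return layers

def nerf_mlp_forward(layers, x, skip_after=5):
    """Run the MLP; the last layer returns raw (R, G, B, sigma) values."""
    h = x
    for i, (w, b) in enumerate(layers):
        h = h @ w + b
        if i < len(layers) - 1:
            h = np.maximum(h, 0.0)  # ReLU on hidden layers only
        if i + 1 == skip_after:
            h = np.concatenate([h, x], axis=-1)  # concatenate the raw input back in
    return h

# 195 = 6 * 32 + 3 input features per point, as in the summary of NeRFMedium
layers = init_nerf_mlp(input_features=195)
out = nerf_mlp_forward(layers, np.ones((1000, 195)))
print(out.shape)                                   # (1000, 4)
print(sum(w.size + b.size for w, b in layers))     # 133124
```

Counting the weights and biases gives the same 133,124 trainable parameters as that summary, which is a quick way to sanity-check how width, depth and the skip connection interact.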

NeRF corresponds to the setup used in the original paper, but it may be difficult to run on some local machines, in which case, NeRFMedium provides very similar performance with smaller memory requirements.

Let's go ahead and install DeepVision with pip:

$ pip install deepvision-toolkit

Instantiating a model is as easy as:

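import deepvision

num_pos = 100 * 100 * 64      # img_height * img_width * num_ray_samples = 640000
input_features = 6 * 32 + 3   # 6 * pos_embed + 3 = 195

model = deepvision.models.NeRFMedium(input_shape=(num_pos, input_features),
                                     backend='tensorflow')
model.summary()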

The model itself is exceedingly simple:

Model: "ne_rftf"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 640000, 195  0           []                               
                                )]                                                                
                                                                                                  
 dense (Dense)                  (None, 640000, 128)  25088       ['input_1[0][0]']                
                                                                                                  
 dense_1 (Dense)                (None, 640000, 128)  16512       ['dense[0][0]']                  
                                                                                                  
 dense_2 (Dense)                (None, 640000, 128)  16512       ['dense_1[0][0]']                
                                                                                                  
 dense_3 (Dense)                (None, 640000, 128)  16512       ['dense_2[0][0]']                
                                                                                                  
 dense_4 (Dense)                (None, 640000, 128)  16512       ['dense_3[0][0]']                
                                                                                                  
 concatenate (Concatenate)      (None, 640000, 323)  0           ['dense_4[0][0]',                
                                                                  'input_1[0][0]']                
                                                                                                  
 dense_5 (Dense)                (None, 640000, 128)  41472       ['concatenate[0][0]']            
                                                                                                  
 dense_6 (Dense)                (None, 640000, 4)    516         ['dense_5[0][0]']                
                                                                                                  
==================================================================================================
Total params: 133,128
Trainable params: 133,124
Non-trainable params: 4
__________________________________________________________________________________________________

We'll take a closer look at how to deal with the outputs of the model and how to render the images produced by the weights of the model, in a moment.

Loading the TinyNeRF Dataset

Since NeRFs can be somewhat expensive to train on larger input images, they were released alongside a small dataset of 100x100 images, dubbed TinyNeRF, to make testing and iteration easier. It has since become a classic dataset for trying out NeRFs and for entering the field, similar to how MNIST became the "Hello World" of digit recognition.

The dataset is available as an .npz file containing the images, the camera's focal length (used to convert pixel coordinates into ray directions) and the camera poses, and can be obtained from the official code release:

import requests
import numpy as np
import matplotlib.pyplot as plt

url = "https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz"
save_path = 'tiny_nerf.npz'

# Download the dataset and fail loudly if the request didn't succeed
response = requests.get(url)
response.raise_for_status()
with open(save_path, "wb") as file:
    file.write(response.content)

data = np.load(save_path)

images, poses, focal = data["images"], data["poses"], data["focal"]

print(images.shape) # (106, 100, 100, 3)
print(poses.shape) # (106, 4, 4)
print(focal) # ~138.89

There are 106 images, 100x100 each, with 3 channels (RGB). All of the images are of a small lego bulldozer. Let's plot the first five images:

fig, ax = plt.subplots(1, 5, figsize=(20, 12))
for i in range(5):
    ax[i].imshow(images[i])
    ax[i].axis('off')
plt.show()

The camera positions supplied in the dataset are crucial for being able to reconstruct the space in which the images were taken, which allows us to project rays through the images and form a volumetric space with the sampled points on each projection.
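Here's a minimal NumPy sketch of how a camera pose and the focal length turn into one ray per pixel. It's an illustration rather than what load_tiny_nerf() does internally: get_rays() is a hypothetical helper, and it assumes the camera convention of the original NeRF code (x to the right, y up, camera looking down its -z axis), with each entry of poses being a 4x4 camera-to-world matrix.

```python
import numpy as np

def get_rays(height, width, focal, c2w):
    """Hypothetical helper: per-pixel ray origins and directions in world space.

    Assumes x right, y up, camera looking down -z; c2w is a 4x4
    camera-to-world pose such as one entry of `poses`.
    """
    i, j = np.meshgrid(np.arange(width, dtype=np.float32),
                       np.arange(height, dtype=np.float32), indexing="xy")
    # Direction through each pixel in camera space, scaled by the focal length
    dirs = np.stack([(i - width * 0.5) / focal,
                     -(j - height * 0.5) / focal,
                     -np.ones_like(i)], axis=-1)
    # Rotate into world space: d_world = R @ d_cam for every pixel
    rays_d = dirs @ c2w[:3, :3].T
    # Every ray starts at the camera centre (the translation part of the pose)
    rays_o = np.broadcast_to(c2w[:3, 3], rays_d.shape)
    return rays_o, rays_d

# With the arrays loaded earlier you could call get_rays(100, 100, float(focal), poses[0]);
# an identity pose keeps the example self-contained.
rays_o, rays_d = get_rays(100, 100, 138.89, np.eye(4))
print(rays_d.shape)  # (100, 100, 3)
```

The ray through the central pixel of an identity pose points straight down -z, which is an easy way to check the sign conventions before trusting the helper with real poses.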

However, since this dataset requires lots of preparation before training, DeepVision offers a load_tiny_nerf() dataset loader that performs the preparation for you. It accepts optional validation_split, pos_embed and num_ray_samples arguments, and returns vanilla tf.data.Dataset objects (a training and a validation set) that you can build high-performance pipelines with:

import deepvision

train_ds, valid_ds = deepvision.datasets.load_tiny_nerf(pos_embed=16,
                                                        num_ray_samples=32,
                                                        save_path='tiny_nerf.npz',
                                                        validation_split=0.2,
                                                        backend='tensorflow') 

You don't strictly need a validation set here, since the point is to fully overfit and memorize the scene; the validation set is created primarily as a sanity check.

Let's take a look at the length and input shapes in the training dataset:

print('Train dataset length:', len(train_ds))
print(train_ds)

This results in:

Train dataset length: 84
<ZipDataset element_spec=(TensorSpec(shape=(100, 100, 3), dtype=tf.float32, name=None), 
                         (TensorSpec(shape=(320000, 99), dtype=tf.float32, name=None), TensorSpec(shape=(100, 100, 32), dtype=tf.float32, name=None)))>

The pos_embed argument sets the number of frequencies used in the positional encoding of the input coordinates. The paper formulates NeRF inputs as 5D coordinates (x, y, z and the viewing angles θ and φ); here, the input_features = 6 * pos_embed + 3 formula used when building the model corresponds to a sine and a cosine at each of pos_embed frequencies for the three spatial coordinates, plus the raw coordinates themselves. Positional encodings were crucial for the network to be able to represent higher-frequency functions, and were a "missing ingredient" in making this sort of technique work in the past: networks struggle to approximate functions with high-frequency variation in color and geometry, due to their bias towards learning low-frequency functions.
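A minimal NumPy sketch of the positional encoding γ(p) from the paper, applied to 3D points only, shows where the 6 * pos_embed + 3 feature count comes from. positional_encoding() is a hypothetical helper; the paper multiplies each frequency by π, but some implementations drop that factor or order the terms differently, so treat this as an illustration rather than DeepVision's exact encoding.

```python
import numpy as np

def positional_encoding(p, num_freqs):
    """Hypothetical helper: gamma(p) = [p, sin(2^k pi p), cos(2^k pi p) for k < num_freqs].

    p has shape (..., 3); the result has shape (..., 3 + 6 * num_freqs).
    """
    features = [p]  # keep the raw coordinates
    for k in range(num_freqs):
        features.append(np.sin((2.0 ** k) * np.pi * p))
        features.append(np.cos((2.0 ** k) * np.pi * p))
    return np.concatenate(features, axis=-1)

points = np.random.default_rng(0).uniform(-1.0, 1.0, size=(5, 3))
print(positional_encoding(points, 16).shape)  # (5, 99)
print(positional_encoding(points, 32).shape)  # (5, 195)
```

The two shapes match the 99 and 195 input features seen with pos_embed=16 and pos_embed=32 in this guide. Each extra frequency doubles how quickly its sine and cosine oscillate across space, which is what lets the MLP fit fine detail.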

The num_ray_samples argument sets the number of points sampled along each ray cast through the image.
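Sampling along a ray simply evaluates o + t·d at num_ray_samples depths t between a near and a far bound. The sketch below uses a hypothetical helper, not DeepVision's code: near=2.0 and far=6.0 in the usage are assumed bounds for the Lego scene, and the optional jitter implements the stratified sampling described in the paper (whether load_tiny_nerf() jitters its samples isn't covered in this guide).

```python
import numpy as np

def sample_along_rays(rays_o, rays_d, near, far, num_samples, rng=None):
    """Hypothetical helper: sample depths t (..., S) and 3D points o + t * d (..., S, 3).

    With rng=None the depths are evenly spaced from near to far; with a NumPy
    Generator, [near, far) is split into equal bins and one depth is drawn
    uniformly inside each bin (stratified sampling).
    """
    batch_shape = rays_o.shape[:-1]
    if rng is None:
        t = np.broadcast_to(np.linspace(near, far, num_samples), batch_shape + (num_samples,))
    else:
        bin_width = (far - near) / num_samples
        lower = near + bin_width * np.arange(num_samples)  # lower edge of each bin
        t = lower + rng.uniform(0.0, bin_width, size=batch_shape + (num_samples,))
    points = rays_o[..., None, :] + t[..., :, None] * rays_d[..., None, :]
    return t, points

# Toy rays: all start at the origin and point down -z
rays_o = np.zeros((100, 100, 3))
rays_d = np.tile([0.0, 0.0, -1.0], (100, 100, 1))
t_vals, points = sample_along_rays(rays_o, rays_d, near=2.0, far=6.0, num_samples=32)
print(t_vals.shape, points.shape)  # (100, 100, 32) (100, 100, 32, 3)
```

Note that t_vals has the same (100, 100, 32) shape as the second ray tensor in the dataset's element spec shown earlier.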

Naturally, the more positional embeddings and ray samples you use, the more detail the approximated volumetric scene can capture, and thus the more detailed the final images will be, at the cost of higher computational and memory requirements.
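To put numbers on that cost, the size of a single image's model input follows directly from the num_pos and input_features formulas used in the training script. The helper below is hypothetical and assumes float32 values (4 bytes each, matching the dataset's element spec).

```python
def input_tensor_bytes(img_height, img_width, num_ray_samples, pos_embed, bytes_per_value=4):
    """Hypothetical helper: bytes for one image's input of shape (num_pos, input_features)."""
    num_pos = img_height * img_width * num_ray_samples  # one point per ray sample
    input_features = 6 * pos_embed + 3                  # encoded features per point
    return num_pos * input_features * bytes_per_value

print(input_tensor_bytes(100, 100, 32, 16))  # 126720000
print(input_tensor_bytes(100, 100, 64, 32))  # 499200000
```

For the 16-embedding, 32-sample setup loaded above, that's roughly 127 MB per image; for 32 embeddings and 64 samples it's roughly 499 MB, before counting activations (each 128-unit hidden layer in the summary above produces another 640,000 x 128 float32 values, about 328 MB).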

Training a NeRF with TensorFlow/Keras and DeepVision

Let's take a look at an end-to-end example of loading the data, preparing the dataset, instantiating a model and training it using DeepVision and the TensorFlow/Keras ecosystem:

import deepvision
from deepvision.datasets import load_tiny_nerf
import tensorflow as tf

config = {
    'img_height': 100,
    'img_width': 100,
    'pos_embed': 32,
    'num_ray_samples': 64,
    'batch_size': 1
}

# Every ray sample of a whole image is fed to the model at once
num_pos = config['img_height'] * config['img_width'] * config['num_ray_samples']
# 3 raw coordinates plus a sine and a cosine per frequency for each of them
input_features = 6 * config['pos_embed'] + 3

train_ds, valid_ds = load_tiny_nerf(pos_embed=config['pos_embed'],
                                    num_ray_samples=config['num_ray_samples'],
                                    save_path='tiny_nerf.npz',
                                    validation_split=0.2,
                                    backend='tensorflow')

# Batch and prefetch so the input pipeline keeps the GPU busy
train_ds = train_ds.batch(config['batch_size']).prefetch(tf.data.AUTOTUNE)
valid_ds = valid_ds.batch(config['batch_size']).prefetch(tf.data.AUTOTUNE)

model = deepvision.models.NeRFMedium(input_shape=(num_pos, input_features),
                                     backend='tensorflow')

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=tf.keras.losses.MeanSquaredError())

# Lower the learning rate when the validation loss stops improving
callbacks = [tf.keras.callbacks.ReduceLROnPlateau()]

history = model.fit(train_ds,
                    epochs=50,
                    validation_data=valid_ds,
                    callbacks=callbacks)

On an NVIDIA GTX 1660 Super, training with 32 positional embeddings and 64 ray samples takes ~1 minute per epoch, while smaller setups, such as 8-16 positional embeddings and 32 ray samples, may take as little as ~7s per epoch:

Epoch 1/50
84/84 [==============================] - 65s 746ms/step - loss: 0.0603 - psnr: 12.6432 - val_loss: 0.0455 - val_psnr: 13.7601 - lr: 0.0010
...
Epoch 50/50
84/84 [==============================] - 55s 658ms/step - loss: 0.0039 - psnr: 24.1984 - val_loss: 0.0043 - val_psnr: 23.8576 - lr: 0.0010

After roughly an hour on a single consumer GPU, the model achieves a PSNR of ~24. With NeRFs, the longer you train, the closer the model gets to representing the original images, so you'll typically see the loss keep decreasing and the PSNR keep increasing as you train more. A ReduceLROnPlateau callback helps by lowering the learning rate to fine-tune the results towards the end of training.

The model reports two metrics: loss and psnr. The loss is the mean squared error over all pixels, which works as a great loss function for NeRFs but is difficult to interpret.

Peak Signal-to-Noise Ratio (PSNR) is the ratio between the maximum possible power of a signal and the power of the noise that degrades its fidelity, usually expressed in decibels. For images, the "noise" is the deviation from the reference image, so PSNR can be computed directly from the mean squared error. This makes it an image quality metric that is much more intuitive for humans to interpret: higher is better.
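Concretely, PSNR is usually computed as 10 · log10(MAX² / MSE) decibels, where MAX is the largest possible pixel value. Here's a minimal sketch, assuming pixel values scaled to [0, 1] (so MAX = 1); psnr_from_mse() is a hypothetical helper, not the metric implementation DeepVision uses.

```python
import math

def psnr_from_mse(mse, max_value=1.0):
    """Hypothetical helper: PSNR in decibels, 10 * log10(MAX^2 / MSE)."""
    return 10.0 * math.log10(max_value ** 2 / mse)

print(round(psnr_from_mse(0.01), 2))    # 20.0
print(round(psnr_from_mse(0.0039), 2))  # 24.09
```

Plugging in the final training loss from the log above (0.0039) gives roughly 24.1 dB, in line with the reported psnr of ~24.2; the small gap is plausibly because the logged metric averages per-step PSNR values rather than converting the averaged loss. Since PSNR is logarithmic, every halving of the MSE adds about 3 dB.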

Already at a PSNR of 24, images become fairly clear, and NeRFs can reach PSNRs of over 40 on TinyNeRF given enough training time.

Visualizing Outputs

The network outputs a tensor of shape (batch_size, 640000, 4), where the channels represent RGB and density, and the 640,000 points (100 x 100 pixels x 64 samples per ray) encode the scene. To represent these as images, we'll want to reshape the tensor to a shape of (batch_size, img_height, img_width, num_ray_samples, 4), split the 4 channels into RGB and sigma, and process them into an image (and, optionally, a depth map and an accumulated-opacity map).

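Specifically, the RGB channels are passed through a sigmoid activation, while the sigma channel is passed through a ReLU activation. The samples along each ray are then combined through volume rendering, which weights each sample's color by its opacity and by how much light still reaches it along the ray. This reduces them to a tensor of shape (batch_size, img_height, img_width, rgb_channels) for the image, plus per-pixel tensors for the depth map, of shape (batch_size, img_height, img_width, depth_channel), and for the accumulated opacity (the sum of the rendering weights along each ray).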
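The following NumPy sketch shows that volume-rendering step, following the discrete formulation in the original paper. It's an illustration, not DeepVision's nerf_render_image_and_depth_tf(): render_rays() is a hypothetical helper, and treating the last interval as effectively infinite (1e10) and adding a 1e-10 epsilon inside the transmittance product are conventions borrowed from the original NeRF code.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def render_rays(raw, t_vals):
    """Hypothetical helper. raw: (..., S, 4) network outputs; t_vals: (..., S) sample depths."""
    rgb = sigmoid(raw[..., :3])             # colors squashed to [0, 1]
    sigma = np.maximum(raw[..., 3], 0.0)    # ReLU keeps density non-negative
    # Distance between adjacent samples; the last interval is treated as "infinite"
    deltas = np.diff(t_vals, axis=-1)
    deltas = np.concatenate([deltas, np.full_like(t_vals[..., :1], 1e10)], axis=-1)
    alpha = 1.0 - np.exp(-sigma * deltas)   # opacity of each segment
    # Transmittance: product of (1 - alpha) over all *previous* samples (exclusive cumprod)
    trans = np.cumprod(1.0 - alpha + 1e-10, axis=-1)
    trans = np.concatenate([np.ones_like(trans[..., :1]), trans[..., :-1]], axis=-1)
    weights = alpha * trans
    rgb_map = np.sum(weights[..., None] * rgb, axis=-2)  # (..., 3)
    depth_map = np.sum(weights * t_vals, axis=-1)         # expected depth per ray
    acc_map = np.sum(weights, axis=-1)                    # accumulated opacity per ray
    return rgb_map, depth_map, acc_map

# Usage with a flat output shaped like the model's (batch, H * W * S, 4)
H, W, S = 100, 100, 64
raw_flat = np.random.default_rng(0).normal(size=(1, H * W * S, 4))
raw = raw_flat.reshape(1, H, W, S, 4)
t_vals = np.tile(np.linspace(2.0, 6.0, S), (1, H, W, 1))
image, depth, acc = render_rays(raw, t_vals)
print(image.shape, depth.shape, acc.shape)  # (1, 100, 100, 3) (1, 100, 100) (1, 100, 100)
```

A ray that only passes through empty space (zero density) renders as black with zero accumulated opacity, while a ray that hits a dense sample early takes almost all of its color and depth from that sample.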

To make this process easier, we can use the nerf_render_image_and_depth_tf() function from volumetric_utils, which uses the model to predict RGB and sigma from the inputs and returns a batch of images, depth maps and accumulated-opacity maps:

import tensorflow as tf
import matplotlib.pyplot as plt
from deepvision.models.volumetric.volumetric_utils import nerf_render_image_and_depth_tf

for batch in train_ds.take(5):
    # Each element holds the target images and the ray data (flattened rays, sample depths)
    (target_images, rays) = batch
    (rays_flat, t_vals) = rays

    image_batch, depth_maps, _ = nerf_render_image_and_depth_tf(model=model,
                                                                rays_flat=rays_flat,
                                                                t_vals=t_vals,
                                                                img_height=config['img_height'],
                                                                img_width=config['img_width'],
                                                                num_ray_samples=config['num_ray_samples'])
    # Plot the rendered image next to its depth map
    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(tf.squeeze(image_batch[0]))
    ax[1].imshow(tf.squeeze(depth_maps[0]))
    plt.show()

Here, we're plotting 5 batches (each with one image) and their depth maps. During training, the model itself relies on the nerf_render_image_and_depth_tf() function to convert predictions to images and calculate the mean squared error and PSNR for the results. Running this code displays each rendered image next to its depth map.

Conclusion

In this guide - we've summarized some of the key elements of Neural Radiance Fields, as a brief introduction to the subject, followed by loading and preparing the TinyNeRF Dataset in TensorFlow, using tf.data, and training a NeRF model with the Keras and DeepVision ecosystems.


David Landup, Author

Entrepreneur, Software and Machine Learning Engineer, with a deep fascination towards the application of Computation and Deep Learning in Life Sciences (Bioinformatics, Drug Discovery, Genomics), Neuroscience (Computational Neuroscience), robotics and BCIs.

Great passion for accessible education and promotion of reason, science, humanism, and progress.

© 2013-2025 Stack Abuse. All rights reserved.
