Future frame generation in autonomous driving
This tutorial demonstrates the role of machine learning in autonomous driving and works through a concrete code example of future frame prediction.
Autonomous driving refers to vehicles capable of navigating and operating without human intervention. This technology relies on an array of sensors (cameras, radar, lidar) and advanced computing systems to interpret the vehicle's surroundings and make real-time decisions.

Table of contents
- The journey thus far
- Understanding autonomous driving levels
- The impact on safety and efficiency
- The role of machine learning
- Perception and sensor fusion
- Object detection and classification
- Decision making and control
- Mapping and localization
- Future frame prediction in autonomous driving
- Unveiling Pix2Pix: a brief overview
- Generative adversarial networks (GANs)
- Conditional GANs for image translation
- Datasets
- Using Weights & Biases
- Building the Pix2Pix model
- Generator (U-Net with 3D Convs)
- Discriminator (PatchGAN with 3D Convs)
- Model training
- Results
- Challenges and future directions
The integration of machine learning algorithms in autonomous driving systems has marked a revolutionary transformation in the automotive industry.
Here's how we got where we are, and a look at where we're going.
The journey thus far
The inception of autonomous driving traces back to the early 20th century when the concept of self-driving vehicles was mere speculation.

Fast forward to recent decades, and we've witnessed remarkable progress in the field, primarily fueled by advancements in artificial intelligence, machine learning, sensor technology, and computing power.
Major automotive and tech giants have invested heavily in research and development, launching prototypes and pilot programs to test autonomous vehicles (AVs) on public roads. Companies like Waymo, Tesla, Uber, and traditional automakers have been at the forefront of this innovation, pushing the boundaries of what's possible.
Understanding autonomous driving levels
The Society of Automotive Engineers (SAE) categorizes autonomous driving into six distinct levels, ranging from Level 0 (no automation) to Level 5 (full automation). These levels provide a standardized framework for assessing the scope and capabilities of self-driving technology, enabling consumers, regulators, and manufacturers to gauge how much human intervention is required at each stage.
For a more detailed explanation of these levels and their implications, check out this in-depth article on SAE automation levels.
- Level 1 and 2: These levels involve driver assistance features like adaptive cruise control and lane-keeping assistance, requiring human supervision.
- Level 3: At this stage, vehicles can handle most driving tasks, but a human driver must be ready to intervene if needed.
- Level 4: Vehicles can operate autonomously in specific conditions and environments without human intervention but might still have limitations.
- Level 5: Full autonomy with no need for human involvement. These vehicles can navigate all conditions and environments without any human input.

The impact on safety and efficiency
Safety stands as a pivotal factor driving the development of autonomous driving technology. Proponents argue that AVs equipped with advanced sensors, cameras, and AI algorithms can potentially reduce accidents caused by human error, which accounts for a significant portion of road accidents today.
Moreover, autonomous vehicles have the potential to optimize traffic flow, reduce congestion, and enhance fuel efficiency through smoother driving patterns and improved route planning.
The role of machine learning
Perception and sensor fusion
Machine learning algorithms play a pivotal role in enabling vehicles to perceive and interpret their environment. They process data from various sensors like cameras, lidar, and radar, and use computer vision techniques to recognize objects, pedestrians, traffic signs, and lane markings. These algorithms aid in sensor fusion, combining information from multiple sources to create a comprehensive understanding of the vehicle's surroundings.
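To make the fusion idea concrete, here is a toy, hedged example: two sensors report the distance to the same obstacle, and we combine them with inverse-variance weighting so the more reliable reading dominates. The ranges and noise levels below are invented purely for illustration.

import numpy as np

# Hypothetical range measurements (in meters) of the same obstacle.
lidar_range, lidar_var = 24.8, 0.05   # lidar: accurate, low variance
radar_range, radar_var = 25.6, 0.40   # radar: noisier, higher variance

# Inverse-variance weighting: the more certain sensor gets more weight.
w_lidar = 1.0 / lidar_var
w_radar = 1.0 / radar_var
fused_range = (w_lidar * lidar_range + w_radar * radar_range) / (w_lidar + w_radar)
fused_var = 1.0 / (w_lidar + w_radar)

print(f"Fused range: {fused_range:.2f} m (variance {fused_var:.3f})")
# The fused estimate stays close to the lidar reading but is refined by radar,
# and its variance is lower than either sensor's alone.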
Object detection and classification
One of the core aspects of autonomous driving is the ability to detect and classify objects accurately. Machine learning models, such as convolutional neural networks (CNNs), are employed for object detection and classification tasks. Models like YOLO (You Only Look Once) and SSD (Single Shot MultiBox Detector) excel in recognizing and localizing objects in real-time, allowing vehicles to react swiftly to potential hazards.
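As one small, concrete piece of that pipeline, the sketch below runs non-maximum suppression, the post-processing step that detectors like YOLO and SSD rely on to discard overlapping candidate boxes. The boxes and scores are dummy values; only TensorFlow's built-in tf.image.non_max_suppression is real.

import tensorflow as tf

# Dummy candidate boxes in [y1, x1, y2, x2] format with confidence scores,
# as a detector head might produce for a single "car" class.
boxes = tf.constant([[0.10, 0.10, 0.50, 0.50],
                     [0.12, 0.11, 0.52, 0.49],   # near-duplicate of the first box
                     [0.60, 0.60, 0.90, 0.95]], dtype=tf.float32)
scores = tf.constant([0.92, 0.85, 0.78], dtype=tf.float32)

# Keep at most 10 boxes, dropping any box that overlaps a higher-scoring one
# by more than 50% IoU.
keep = tf.image.non_max_suppression(boxes, scores, max_output_size=10, iou_threshold=0.5)
print(keep.numpy())  # -> [0 2]: the duplicate detection is suppressed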
Decision making and control
Machine learning algorithms assist in decision-making processes by analyzing the data collected from sensors and determining appropriate actions. Reinforcement learning and deep reinforcement learning techniques enable vehicles to learn optimal driving behaviors through trial and error, making decisions such as lane changes, speed adjustments, and navigating complex traffic scenarios.
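As a deliberately simplified, hypothetical illustration of the underlying idea (not how a production driving stack is built), here is a single tabular Q-learning update for a toy "keep lane vs. change lane" decision; every number is invented.

import numpy as np

# Toy setup: 2 states (0 = clear road, 1 = slow car ahead) and 2 actions
# (0 = keep lane, 1 = change lane). All values are illustrative only.
Q = np.zeros((2, 2))
alpha, gamma = 0.1, 0.9        # learning rate and discount factor

state, action = 1, 1           # slow car ahead, agent chooses to change lane
reward, next_state = 1.0, 0    # maneuver succeeds, road is clear afterwards

# Standard Q-learning update: Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
Q[state, action] += alpha * (reward + gamma * Q[next_state].max() - Q[state, action])
print(Q)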
Mapping and localization
Machine learning contributes to mapping and localization, crucial components for autonomous navigation. Algorithms process sensor data to create high-definition maps and precisely locate the vehicle within its environment using techniques like SLAM (Simultaneous Localization and Mapping).
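To give a flavor of the estimation machinery involved, here is a minimal one-dimensional Kalman filter step, a toy stand-in for the probabilistic localization that real systems perform in 2D/3D, often alongside SLAM. All values are illustrative.

# Minimal 1D Kalman filter step: predict with odometry, correct with a
# GPS-like measurement. Real systems fuse many sensors in higher dimensions.
x, P = 100.0, 4.0          # current position estimate (m) and its variance
u, Q = 1.5, 0.25           # odometry says we moved 1.5 m; process noise
z, R = 101.8, 1.0          # GPS-like position measurement and its noise

# Predict: move the estimate forward with odometry; uncertainty grows.
x_pred = x + u
P_pred = P + Q

# Update: blend the prediction with the measurement via the Kalman gain.
K = P_pred / (P_pred + R)
x = x_pred + K * (z - x_pred)
P = (1 - K) * P_pred
print(f"position ~ {x:.2f} m, variance ~ {P:.2f}")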

Now let's turn to a relatively new research problem: future frame prediction, also known as frame extrapolation. For autonomous vehicles, predicting future frames helps anticipate the movement of objects and understand dynamic environments, which is crucial for safe navigation and decision-making.
Future frame prediction in autonomous driving
- Forecasting dynamic environments: Autonomous vehicles navigate through constantly changing environments filled with moving objects, pedestrians, other vehicles, and unpredictable scenarios. Future frame prediction aims to forecast the subsequent frames in a video sequence, allowing vehicles to anticipate the movements and behavior of surrounding entities.
- Real-time decision making: Accurate prediction of future frames empowers autonomous vehicles to make informed and timely decisions. By anticipating the trajectories of objects and understanding the evolving scene dynamics, these vehicles can plan safe and efficient paths, execute maneuvers, and avoid potential collisions.
- Enabling proactive safety measures: Future frame prediction plays a crucial role in enhancing the safety measures of autonomous driving systems. Anticipating the actions of other vehicles or pedestrians enables the vehicle to proactively take precautionary measures, thus mitigating risks and ensuring a safer driving experience.

In the evolving landscape of machine learning and computer vision, image-to-image translation models like Pix2Pix have proven useful well beyond their original purpose, including for future frame prediction. Forecasting upcoming frames in a video sequence has attracted significant attention, and Pix2Pix, a conditional generative adversarial network (GAN), is a strong fit for the task. In the rest of this article, we'll build a future frame prediction model based on Pix2Pix.
Unveiling Pix2Pix: a brief overview
Generative adversarial networks (GANs)
Pix2Pix is built upon the foundational architecture of GANs, a class of machine learning models that consist of two neural networks: a generator and a discriminator. These networks work in tandem, where the generator generates images or sequences, and the discriminator evaluates the authenticity of these outputs.
Conditional GANs for image translation
Pix2Pix falls under the category of conditional GANs, specifically designed for image-to-image translation tasks. It excels in learning the mapping between input and output images, enabling tasks like style transfer, image colorization, and importantly, future frame prediction in video sequences.
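Concretely, Pix2Pix (following the original paper's formulation) trains the generator G and discriminator D with a conditional adversarial loss plus an L1 reconstruction term. In our setting, x is the observed frame sequence, y the future frames, and z the noise, which is realized via dropout:

\mathcal{L}_{\text{cGAN}}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x,z}\left[\log\left(1 - D(x, G(x, z))\right)\right]

\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\left[\lVert y - G(x, z)\rVert_1\right]

G^{*} = \arg\min_{G}\max_{D}\ \mathcal{L}_{\text{cGAN}}(G, D) + \lambda\,\mathcal{L}_{L1}(G)

The LAMBDA = 100 weighting used in the loss code later on plays the role of λ.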
Datasets
Several datasets are available for future frame prediction tasks, aiding researchers and practitioners in training and evaluating models designed for video sequence forecasting. Here are some notable datasets commonly used for future frame prediction:
- MovingMNIST: This dataset is an extension of the MNIST dataset, containing sequences of handwritten digit images. Each sequence consists of multiple frames depicting the movement of digits. It's commonly used for experimenting with early-stage video prediction models.
- UCF101: UCF101 is a widely used action recognition dataset containing video clips of 101 action categories. While primarily used for action recognition, it's also utilized for future frame prediction tasks due to its diverse set of actions and movement patterns.
- KTH Actions: This dataset contains videos of six human actions (walking, jogging, running, etc.) recorded under various scenarios. It's popular for action recognition and can be adapted for future frame prediction tasks.
- UCF Sports Actions: Another subset of the UCF dataset, UCF Sports Actions focuses on sports-related actions, providing a diverse range of movements and actions in sports scenes.
- YouTube-8M: YouTube-8M is a large-scale video dataset containing millions of YouTube video IDs along with audio and visual features. Although primarily used for video classification, segments of this dataset can be utilized for future frame prediction tasks.
- KITTI Dataset: This dataset is commonly used in autonomous driving research and contains various sequences recorded from a moving vehicle. It includes RGB images, depth, LiDAR, and GPS data, making it suitable for future frame prediction in dynamic driving scenarios.
- Cityscapes Dataset: Focused on urban scene understanding, the Cityscapes dataset comprises high-resolution images captured in street scenes. It contains annotated images and videos useful for tasks like semantic segmentation and future frame prediction in urban environments.
- BDD100K: This large-scale driving video dataset contains diverse driving scenes captured from various locations. It includes high-resolution videos with object annotations, making it suitable for future frame prediction in autonomous driving contexts.
Using Weights & Biases
Weights & Biases is a popular experiment tracking and model management tool that helps you visualize and compare your training runs in real-time. By integrating W&B into your workflow, you can:
- Track metrics such as loss, accuracy, or custom metrics (e.g., PSNR, SSIM) as your model trains.
- Visualize changes in real-time, allowing you to spot performance plateaus or rapid improvements quickly.
- Compare runs side-by-side to see which hyperparameters or model versions yield the best results.
- Organize artifacts like model weights, predictions, and logs in a cohesive and easily shareable dashboard.
- Collaborate more effectively with teammates by sharing interactive dashboards instead of raw log files.
Create a Weights & Biases account and install the wandb package:
pip install wandb
Then log in using:
wandb login
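Once logged in, a minimal run looks roughly like the sketch below. The project name and metric here are placeholders; the real logging for this tutorial happens inside the training loop later on.

import wandb

# Start a run; the project name is a placeholder for this sketch.
run = wandb.init(project="pix2pix", job_type="example")

# Log metrics as they become available; each call adds a step to the run's charts.
for step in range(3):
    run.log({"example_metric": 1.0 / (step + 1)})

run.finish()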
Building the Pix2Pix model
In this script, we adapt the traditional Pix2Pix architecture to handle video frames by using 3D convolutions. Specifically:
Generator (U-Net with 3D Convs)
- Downsampling: Each downsample block uses Conv3D to extract features across spatial (height/width) and temporal (frame sequence) dimensions. We apply BatchNormalization for training stability and LeakyReLU for nonlinearity.
- Upsampling: Each upsample block uses Conv3DTranspose to reconstruct the resolution of the frames, assisted by skip connections (via Concatenate) that preserve details from earlier layers.
- Output: A final Conv3DTranspose layer with tanh activation produces normalized predicted frames.
Discriminator (PatchGAN with 3D Convs)
- We take both the real (input + target) and generated (input + model output) sequences and feed them into Conv3D layers to decide whether each local spatiotemporal patch is authentic or fabricated.
- ZeroPadding3D and BatchNormalization stabilize the deeper convolution stages, and the final output is a single-channel map indicating real vs. fake regions.
Using 3D convolutions ensures the network learns how objects move over time, improving future frame prediction beyond what a 2D-only approach could achieve.
import tensorflow as tf
from tensorflow.keras.layers import (Conv3D, Conv3DTranspose, BatchNormalization, LeakyReLU,
                                     Dropout, ReLU, Input, Concatenate, ZeroPadding3D)
from tensorflow.keras import Model, Sequential

output_channel = 3


def downsample(filters, size, shape, apply_batchnorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    result = Sequential()
    result.add(Conv3D(filters, size, strides=(1, 2, 2), padding='same',
                      batch_input_shape=shape, kernel_initializer=initializer, use_bias=False))
    if apply_batchnorm:
        result.add(BatchNormalization())
    result.add(LeakyReLU())
    return result


def upsample(filters, size, shape, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)
    result = Sequential()
    result.add(Conv3DTranspose(filters, size, strides=(1, 2, 2), padding='same',
                               batch_input_shape=shape, kernel_initializer=initializer, use_bias=False))
    result.add(BatchNormalization())
    if apply_dropout:
        result.add(Dropout(0.5))
    result.add(ReLU())
    return result


def build_generator():
    inputs = Input(shape=[7, 256, 256, 3])

    down_stack = [
        downsample(64, (1, 4, 4), (None, 7, 256, 256, 3), apply_batchnorm=False),
        downsample(128, (1, 4, 4), (None, 7, 128, 128, 64)),
        downsample(256, (1, 4, 4), (None, 7, 64, 64, 128)),
        downsample(512, (1, 4, 4), (None, 7, 32, 32, 256)),
        downsample(512, (1, 4, 4), (None, 7, 16, 16, 512)),
        downsample(512, (1, 4, 4), (None, 7, 8, 8, 512)),
        downsample(512, (1, 4, 4), (None, 7, 4, 4, 512)),
        downsample(512, (1, 4, 4), (None, 7, 2, 2, 512)),
    ]

    up_stack = [
        upsample(512, (1, 4, 4), (None, 7, 1, 1, 512), apply_dropout=True),
        upsample(512, (1, 4, 4), (None, 7, 2, 2, 1024), apply_dropout=True),
        upsample(512, (1, 4, 4), (None, 7, 4, 4, 1024), apply_dropout=True),
        upsample(512, (1, 4, 4), (None, 7, 8, 8, 1024)),
        upsample(256, (1, 4, 4), (None, 7, 16, 16, 1024)),
        upsample(128, (1, 4, 4), (None, 7, 32, 32, 512)),
        upsample(64, (1, 4, 4), (None, 7, 64, 64, 256)),
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = Conv3DTranspose(output_channel, (1, 4, 4), strides=(1, 2, 2), padding='same',
                           kernel_initializer=initializer, activation='tanh')

    # Encoder: downsample and collect skip connections.
    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)
    skips = reversed(skips[:-1])

    # Decoder: upsample and concatenate the matching skip connection.
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = Concatenate()([x, skip])

    x = last(x)
    return Model(inputs=inputs, outputs=x)


generator = build_generator()
generator.summary()


def downs(filters, size, apply_batch_norm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    result = Sequential()
    result.add(Conv3D(filters, size, strides=2, padding='same',
                      kernel_initializer=initializer, use_bias=False))
    if apply_batch_norm:
        result.add(BatchNormalization())
    result.add(LeakyReLU())
    return result


def build_discriminator():
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = Input(shape=[7, 256, 256, 3], name='input_img')
    tar = Input(shape=[7, 256, 256, 3], name='target_img')
    x = Concatenate()([inp, tar])

    down1 = downs(64, (1, 4, 4), False)(x)
    down2 = downs(128, (1, 4, 4))(down1)
    down3 = downs(256, (1, 4, 4))(down2)

    zero_pad1 = ZeroPadding3D()(down3)
    conv = Conv3D(512, (1, 4, 4), strides=1, kernel_initializer=initializer, use_bias=False)(zero_pad1)
    batchnorm1 = BatchNormalization()(conv)
    leaky_relu = LeakyReLU()(batchnorm1)

    zero_pad2 = ZeroPadding3D()(leaky_relu)
    last = Conv3D(1, (1, 4, 4), strides=1, kernel_initializer=initializer)(zero_pad2)

    return Model(inputs=[inp, tar], outputs=last)


discriminator = build_discriminator()
discriminator.summary()
Different losses can be used for image generation; here we follow the standard Pix2Pix recipe of a binary cross-entropy adversarial loss combined with an L1 reconstruction term (weighted by LAMBDA), optimized with Adam.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
LAMBDA = 100


def gen_loss(disc_generated_output, gen_output, target):
    # Adversarial loss: the generator wants the discriminator to output "real" (ones).
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    # L1 reconstruction loss between predicted and ground-truth frames.
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_gen_loss = gan_loss + (LAMBDA * l1_loss)
    return total_gen_loss, gan_loss, l1_loss


def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    total_disc_loss = real_loss + generated_loss
    return total_disc_loss


generator_optimizer = tf.keras.optimizers.Adam(2e-6, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-6, beta_1=0.5)
Model training
Because GAN training alternates updates to the generator and the discriminator, we use a custom training step instead of a stock fit loop, and we log the losses from each step to Weights & Biases.
import os
import random

import numpy as np
import wandb
from tqdm import tqdm

# `path` should point to the directory containing the preprocessed frame arrays.
path = "path/to/preprocessed_frames"


@tf.function
def train_step(input_img, target):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_img, training=True)
        disc_real_output = discriminator([input_img, target], training=True)
        disc_generated_output = discriminator([input_img, gen_output], training=True)
        gen_total_loss, gen_gan_loss, gen_l1_loss = gen_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    return gen_total_loss, gen_gan_loss, gen_l1_loss, disc_loss


def fit(epochs):
    run = wandb.init(project="pix2pix", job_type="model-training")
    batch_size = 4

    for epoch in tqdm(range(epochs)):
        print("Epoch {}".format(epoch))
        for _, _, files in os.walk(path):
            random.shuffle(files)
            total_input = np.zeros((batch_size, 7, 256, 256, 3))
            total_output = np.zeros((batch_size, 7, 256, 256, 3))
            num = 0
            for file in files:
                m = int(file.split('.')[0])
                if m < 17850:  # only use indices where the following 13 frames are available
                    # Each file holds one preprocessed 256x256x3 frame stored as a NumPy
                    # array (the original files carry a .jpg extension).
                    frames = [np.load(path + "/{}.jpg".format(m + i)) for i in range(14)]
                    input_seq = np.reshape(np.concatenate(frames[:7]), (7, 256, 256, 3)).astype(np.float32)
                    target_seq = np.reshape(np.concatenate(frames[7:]), (7, 256, 256, 3)).astype(np.float32)

                    total_input[num] = input_seq / 255.0
                    total_output[num] = target_seq / 255.0
                    num += 1

                    if num == batch_size:
                        num = 0
                        batch_in = tf.convert_to_tensor(total_input, dtype=tf.float32)
                        batch_out = tf.convert_to_tensor(total_output, dtype=tf.float32)
                        gen_total_loss, gen_gan_loss, gen_l1_loss, disc_loss = train_step(batch_in, batch_out)
                        run.log({"epoch": epoch,
                                 "gen_total_loss": float(gen_total_loss),
                                 "gen_gan_loss": float(gen_gan_loss),
                                 "gen_l1_loss": float(gen_l1_loss),
                                 "disc_loss": float(disc_loss)})
                        total_input = np.zeros((batch_size, 7, 256, 256, 3))
                        total_output = np.zeros((batch_size, 7, 256, 256, 3))

    run.finish()


EPOCHS = 100  # placeholder; set this to your training budget
fit(EPOCHS)
Training metrics
Results
Below are some of the results, showing six input frames and the predicted succeeding frame.
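If you want to reproduce this kind of visualization, a minimal sketch looks roughly like the following. It assumes the trained generator from above and an input batch of shape (1, 7, 256, 256, 3); the random array is only a placeholder for real preprocessed frames, and the project name is arbitrary.

import numpy as np
import wandb

# Placeholder: replace with 7 consecutive preprocessed frames, scaled the same
# way as during training and shaped (1, 7, 256, 256, 3).
input_frames = np.random.rand(1, 7, 256, 256, 3).astype(np.float32)

# Predict the next 7 frames with the trained generator.
predicted = generator(input_frames, training=False).numpy()


def to_uint8(frame):
    # Clip to [0, 1] and convert to 8-bit for image logging.
    return (np.clip(frame, 0, 1) * 255).astype(np.uint8)


# Log the last observed frame next to the first predicted frame for inspection.
run = wandb.init(project="pix2pix", job_type="evaluation")
run.log({
    "last_input_frame": wandb.Image(to_uint8(input_frames[0, -1])),
    "first_predicted_frame": wandb.Image(to_uint8(predicted[0, 0])),
})
run.finish()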

Challenges and future directions
Long-term prediction
Handling long-term predictions remains a challenge, requiring advancements in capturing and modeling intricate temporal dependencies accurately.
Robustness and generalization
Ensuring robustness and generalization of Pix2Pix models across diverse scenarios and unpredictable environments remains a focus for ongoing research.
Real-world deployment
Further exploration is needed to deploy Pix2Pix-based future frame prediction models in real-world applications, addressing scalability, efficiency, and reliability.