Mixed precision training with tf.keras

Sayak Paul

In this article, we are going to see how to incorporate mixed precision (MP) training into your tf.keras training workflows. Mixed precision training was proposed by NVIDIA in this paper. It lets us train large neural networks significantly faster, with little to no decrease in their performance. Here’s what we are going to cover -

If you want to get into the details of mixed precision training, then I highly suggest the following resources:

Let’s get started!

Incorporating mixed precision training in tf.keras (TensorFlow 2.0)

TensorFlow 2.0 offers the following options to help you easily incorporate mixed precision training -

In my experience, the last two are more performant. Some configuration is needed, however, to activate mixed precision training; we will see it in a later section. But before that, let’s discuss some pointers to keep in mind while using mixed precision training.

Things to remember while using MP training

Please note that you will not be able to attain considerable performance improvements with MP training out of the box. There are some pointers we need to keep in mind when using MP training -

Hands-on mixed precision training with tf.keras

In this section, we will see some hands-on examples for using mixed precision training with tf.keras. You can find the full-length experiments in this repo.

Introduction to the dataset

The data for my experiments came from this Analytics Vidhya Hackathon. You’re given a set of images like the following and you need to predict the category of a given ship -

The labels of the images were given as the following encodings -

{'Cargo': 1,
 'Military': 2,
 'Carrier': 3,
 'Cruise': 4,
 'Tankers': 5}
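These codes start at 1, while a 5-unit softmax head (like the num_class=5 model defined later) produces class indices 0 to 4. The article does not say how the labels were fed to the loss. The following is a minimal pure-Python sketch that assumes you either shift the labels to zero-based indices (for sparse categorical cross-entropy) or one-hot encode them. The helper names are hypothetical.

```python
# Label encoding from the hackathon data (1-based, as given above)
LABELS = {'Cargo': 1, 'Military': 2, 'Carrier': 3, 'Cruise': 4, 'Tankers': 5}
NUM_CLASSES = len(LABELS)

# Reverse lookup from a zero-based model output index to the class name
INDEX_TO_NAME = {code - 1: name for name, code in LABELS.items()}


def to_zero_based(code):
    """Shift a 1-based hackathon label into the 0..NUM_CLASSES-1 range."""
    if not 1 <= code <= NUM_CLASSES:
        raise ValueError(f"unknown label code: {code}")
    return code - 1


def one_hot(code):
    """One-hot vector for a 1-based label, e.g. for categorical cross-entropy."""
    vec = [0.0] * NUM_CLASSES
    vec[to_zero_based(code)] = 1.0
    return vec


print(one_hot(3))        # [0.0, 0.0, 1.0, 0.0, 0.0]
print(INDEX_TO_NAME[4])  # Tankers
```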

There are 6252 images in the train set and 2680 images in the test set. Unfortunately, the images in the test set are completely unlabeled (as they should be in a hackathon). For the experiments, I only used the 6252 training images.

The dataset comes in the following format -

├── train
│   ├── images [8932 entries]
│   └── train.csv
└── test_ApKoW4T.csv

where train.csv and test_ApKoW4T.csv contain the names of the training and testing images, respectively. The images folder holds all 8932 images (6252 training + 2680 test).
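Because train.csv only lists file names and labels, you need to join it with the images folder before building an input pipeline. Below is a minimal sketch using Python's csv module. It assumes the CSV has columns named image and category. These column names are an assumption, so check your copy. The function name is hypothetical.

```python
import csv
import os


def load_train_index(lines, images_dir):
    """Return (image_path, zero_based_label) pairs from the rows of train.csv.

    `lines` is any iterable of CSV text lines, such as an open file. The column
    names 'image' and 'category' are assumptions; adjust them to your copy.
    """
    pairs = []
    for row in csv.DictReader(lines):
        path = os.path.join(images_dir, row["image"])
        pairs.append((path, int(row["category"]) - 1))  # labels are 1-based
    return pairs


# Usage (paths follow the directory layout shown above):
# with open("train/train.csv", newline="") as f:
#     pairs = load_train_index(f, "train/images")
```

Accepting any iterable of lines rather than a path keeps the function easy to test and lets you pass an already-open file.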

Now that we have a fair introduction to the dataset, we can proceed towards exploring the hands-on examples.

Setting explicit policies for MP training

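tf.keras.mixed_precision.experimental.set_policy lets us set the default (global) policy for the layers of a network. A policy specifies the dtypes a layer uses: the dtype of its computations and the dtype of its variables. For example, 'mixed_float16' computes in float16 while keeping variables in float32. There are multiple ways to set policies for the layers in tf.keras -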

import tensorflow as tf

policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(10, dtype=policy),
    tf.keras.layers.Dense(10, dtype=policy),
    # Softmax should be done in float32 for numeric stability.
    tf.keras.layers.Activation('softmax', dtype='float32')
])
model.compile(...)  # Compile `model` as usual
model.fit(...)  # Train `model`

(Code taken from this example)
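Why does 'mixed_float16' keep variables in float32? Small weight updates can vanish entirely in float16. Here is a toy pure-Python illustration, not TensorFlow code. It emulates float16 and float32 rounding with the standard library's struct 'e' (half-precision) and 'f' (single-precision) formats, and it applies 100 updates of 1e-4 to a weight of 1.0.

```python
import struct


def to_float16(x):
    """Round to the nearest IEEE 754 half-precision (float16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_float32(x):
    """Round to the nearest IEEE 754 single-precision (float32) value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def apply_updates(weight, update, steps, cast):
    """Repeatedly add `update` to `weight`, storing the result with `cast`."""
    w = cast(weight)
    for _ in range(steps):
        w = cast(w + cast(update))
    return w


w16 = apply_updates(1.0, 1e-4, steps=100, cast=to_float16)
w32 = apply_updates(1.0, 1e-4, steps=100, cast=to_float32)
print(w16)            # 1.0: each update is below float16's resolution near 1.0
print(round(w32, 4))  # 1.01
```

Near 1.0, float16 values are spaced about 0.001 apart, so 1.0001 rounds back to 1.0 on every step. Float32 keeps the updates, which is the reasoning behind storing variables in float32.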

To get the speedups from these policies, you should also enable the XLA compiler (an introduction to XLA is available here) like so - tf.config.optimizer.set_jit(True). Note that this needs to be done on a per-session basis, i.e. if you plan to use the XLA compiler (and in most cases you really should), you need to enable it for each new session.

It’s also a good idea to clear any existing session before you start your experiments to prevent unforeseen issues -

from tensorflow.keras import backend as K

K.clear_session()

After enabling the XLA compiler, we set the default policy of the layers like so - tf.keras.mixed_precision.experimental.set_policy('mixed_float16'). We can now define our model -

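import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model

def create_model(img_size=(224, 224), num_class=5, train_base=True):
    # Accept float16 image inputs
    input_layer = Input(shape=(img_size[0], img_size[1], 3), dtype=tf.float16)
    # Pre-trained ResNet50 without its ImageNet classification head
    base = ResNet50(input_tensor=input_layer,
                    include_top=False,
                    weights="imagenet")
    base.trainable = train_base
    x = base.output
    x = GlobalAveragePooling2D()(x)
    # Keep the softmax output in float32 for numeric stability
    preds = Dense(num_class, activation="softmax", dtype=tf.float32)(x)
    return Model(inputs=input_layer, outputs=preds)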

We added a multi-class classification head on top of a pre-trained ResNet50 network. Note that the inputs of the model should be in float16. Also, take a look at the dtype of the Dense layer. When I trained the network on the data described above on my GCP Notebook instance (with a Tesla T4 GPU), I got the following -

Train for 78 steps
Epoch 1/5
78/78 [==============================] - 71s 916ms/step - loss: 0.3686 - accuracy: 0.7503
Epoch 2/5
78/78 [==============================] - 18s 227ms/step - loss: 0.2259 - accuracy: 0.7822
Epoch 3/5
78/78 [==============================] - 18s 229ms/step - loss: 0.1535 - accuracy: 0.7880
Epoch 4/5
78/78 [==============================] - 18s 230ms/step - loss: 0.1268 - accuracy: 0.7979
Epoch 5/5
78/78 [==============================] - 18s 230ms/step - loss: 0.1180 - accuracy: 0.7928

As a sanity check, I trained the same network without the mixed precision configurations and I got the following results -

Train for 78 steps
Epoch 1/5
78/78 [==============================] - 80s 1s/step - loss: 0.3619 - accuracy: 0.7359
Epoch 2/5
78/78 [==============================] - 59s 756ms/step - loss: 0.2056 - accuracy: 0.7683
Epoch 3/5
78/78 [==============================] - 60s 770ms/step - loss: 0.1513 - accuracy: 0.7761
Epoch 4/5
78/78 [==============================] - 61s 778ms/step - loss: 0.1305 - accuracy: 0.7835
Epoch 5/5
78/78 [==============================] - 61s 788ms/step - loss: 0.1017 - accuracy: 0.7869
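To put a number on the difference, here’s a small sketch that parses the two Keras logs above with Python's re module and compares total epoch time. I skip the first epoch in both runs, assuming it mostly reflects one-time costs such as graph tracing and compilation. The helper names are hypothetical.

```python
import re

# Per-epoch Keras progress lines copied from the two runs above
MIXED_LOG = """\
78/78 [==============================] - 71s 916ms/step - loss: 0.3686 - accuracy: 0.7503
78/78 [==============================] - 18s 227ms/step - loss: 0.2259 - accuracy: 0.7822
78/78 [==============================] - 18s 229ms/step - loss: 0.1535 - accuracy: 0.7880
78/78 [==============================] - 18s 230ms/step - loss: 0.1268 - accuracy: 0.7979
78/78 [==============================] - 18s 230ms/step - loss: 0.1180 - accuracy: 0.7928
"""
FP32_LOG = """\
78/78 [==============================] - 80s 1s/step - loss: 0.3619 - accuracy: 0.7359
78/78 [==============================] - 59s 756ms/step - loss: 0.2056 - accuracy: 0.7683
78/78 [==============================] - 60s 770ms/step - loss: 0.1513 - accuracy: 0.7761
78/78 [==============================] - 61s 778ms/step - loss: 0.1305 - accuracy: 0.7835
78/78 [==============================] - 61s 788ms/step - loss: 0.1017 - accuracy: 0.7869
"""

# Matches the epoch duration right after the progress bar, e.g. "] - 18s "
EPOCH_SECONDS = re.compile(r"\]\s+-\s+(\d+)s\s")


def epoch_seconds(log_text):
    """Extract the wall-clock seconds of each epoch from a Keras log."""
    return [int(m.group(1)) for m in EPOCH_SECONDS.finditer(log_text)]


def steady_state_speedup(baseline_log, candidate_log, skip_first=1):
    """Ratio of total epoch time, ignoring the first `skip_first` epochs."""
    baseline = epoch_seconds(baseline_log)[skip_first:]
    candidate = epoch_seconds(candidate_log)[skip_first:]
    return sum(baseline) / sum(candidate)


speedup = steady_state_speedup(FP32_LOG, MIXED_LOG)
print(f"{speedup:.2f}x")  # 3.35x
```

Epochs 2 to 5 took 241 seconds without mixed precision and 72 seconds with it.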

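You can see the improvement in the training time: after the first epoch, each epoch took about 18 seconds with mixed precision versus about 60 seconds without it. Let’s visualize some of the most important metrics here: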

The above set of graphs shows the following metrics across the different experiments (the experiment names appear in the legends, though they are a bit hard to read) -

It can be clearly seen that mixed precision training lets us train our networks faster, without any loss in performance. Note that “training_time” is not tracked by W&B by default, but logging it is as easy as the following block of code:

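import time
import wandb

start = time.time()
model.fit(...)  # Train `model`
training_time = time.time() - start
# Log the wall-clock training time to the current W&B run
wandb.log({"training_time": training_time})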

You can log many other things, like tables, images, and audio. If you are interested in exploring that space, please check out the documentation here. W&B also lets us generate the above plots within our Jupyter Notebooks with the magic command %%wandb (read more here), but I like to keep these plots separate from my notebooks.

Additionally, I created the following scatter plot to capture the trends in “training time vs. loss” using the “Add a visualization” button you get on your run page.

Creating the scatter plot was a matter of a few keystrokes, and you can easily do the same.

As you would expect, the graph is interactive as well.

Let’s now explore the second option for doing MP training in tf.keras.

Loss scaling for MP training

Loss scaling is an important concept in MP training because it prevents the numerical underflow of small gradient values that can happen when computations are done in lower (float16) precision. To be able to use tf.keras.mixed_precision.experimental.LossScaleOptimizer, you need to enable mixed precision with tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True}), along with enabling the XLA compiler. Be sure to review some of the important things you may need to follow if you are using loss scaling with a custom training loop. We can use the same model we saw in the previous section, and to plug in loss scaling we would do the following -

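import tensorflow as tf
from tensorflow.keras.optimizers import Adam

opt = Adam(learning_rate=1e-4)
# Wrap the optimizer so the loss is scaled (and the gradients unscaled)
# automatically; "dynamic" loss scaling is an assumed choice here
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, loss_scale="dynamic")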

And that’s it! You should get similar performance as you got with policies in this case as well.
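To see why loss scaling matters and what a dynamic loss scale does, here’s a toy pure-Python sketch. It emulates float16 rounding with the standard library's struct 'e' (half-precision) format and uses a hypothetical ToyDynamicLossScale class. This is not TensorFlow's implementation. Its defaults (start at 2**15, double after 2000 overflow-free steps, halve and skip the step on overflow) follow my understanding of the usual dynamic loss-scaling rule and are only illustrative.

```python
import math
import struct


def to_float16(x):
    """Round to the nearest float16 value; values too large become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


# A tiny gradient underflows to zero in float16, so its update is lost.
true_grad = 1e-8  # illustrative value
print(to_float16(true_grad))  # 0.0

# Scaling the loss by S scales every gradient by S (chain rule), making it
# representable; we unscale in higher precision before the weight update.
S = 2.0 ** 16  # illustrative static loss scale
recovered = to_float16(true_grad * S) / S
print(abs(recovered - true_grad) / true_grad < 1e-3)  # True


class ToyDynamicLossScale:
    """Toy dynamic loss scale (hypothetical helper, not TensorFlow's class)."""

    def __init__(self, initial_scale=2.0 ** 15, growth_interval=2000, factor=2.0):
        self.scale = initial_scale
        self.growth_interval = growth_interval  # finite steps before growing
        self.factor = factor                    # grow/shrink multiplier
        self.good_steps = 0

    def unscale(self, scaled_grads):
        """Undo the loss scaling on the gradients before applying them."""
        return [g / self.scale for g in scaled_grads]

    def update(self, scaled_grads):
        """Adjust the scale; return True if this step should be applied."""
        if all(math.isfinite(g) for g in scaled_grads):
            self.good_steps += 1
            if self.good_steps >= self.growth_interval:
                self.scale *= self.factor  # many clean steps: try a larger scale
                self.good_steps = 0
            return True
        # Overflow (inf/NaN): skip the update and shrink the scale, floor of 1
        self.scale = max(self.scale / self.factor, 1.0)
        self.good_steps = 0
        return False


scaler = ToyDynamicLossScale(initial_scale=2.0 ** 16)
true_grads = [1.0, 1e-8]  # illustrative gradient values
for _ in range(5):  # retrying the same gradients only for illustration
    scaled = [to_float16(g * scaler.scale) for g in true_grads]
    step = scaler.unscale(scaled)  # uses the scale of this step
    if scaler.update(scaled):
        break  # a real loop would apply `step` to the weights here
print(scaler.scale)  # 32768.0
```

At 2**16 the gradient of 1.0 overflows float16 (its maximum is 65504), so the step is skipped and the scale is halved. At 2**15 both gradients fit, and the tiny one no longer underflows.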

W&B for comparing our experiments

Mixed precision does not only result in faster training; it is also supposed to reduce the memory footprint of the models, because many tensors and computations use 16-bit floats instead of 32-bit floats. This might not be that evident here, since we are dealing with a small amount of data. With a larger dataset (say, 500 MB), the difference would have been much more evident. Still, it is good practice to keep an eye on the memory footprint of your machine learning models in general, and W&B makes that very easy to visualize. On my project page, for individual runs like this, you get a tab called “System”. It provides you with a bunch of information like so -

Apart from the above system metrics, you also get information about GPU memory usage at the bottom. In my experiment, I used TensorFlow's data API (tf.data), which builds efficient input pipelines that overlap data loading and preprocessing with training on your GPU (if available). This sets tf.data apart from alternatives like ImageDataGenerator. You can use the GPU memory usage graphs for this kind of comparison.
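As a rough guide to the memory savings, storing a tensor in float16 takes half the bytes of float32. Under the 'mixed_float16' policy, variables stay in float32, so the savings come mainly from activations and other intermediate tensors. The following small arithmetic sketch uses the struct module's sizes for the 'f' (float32) and 'e' (float16) formats. The batch size of 32 is an assumption for illustration.

```python
import struct

# Bytes per element for IEEE 754 single (float32) and half (float16) precision
BYTES_FP32 = struct.calcsize("f")  # 4
BYTES_FP16 = struct.calcsize("e")  # 2


def tensor_bytes(shape, bytes_per_element):
    """Memory needed to store a dense tensor of the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_element


# A batch of 224x224 RGB images; the batch size of 32 is an assumption
batch_shape = (32, 224, 224, 3)
fp32_bytes = tensor_bytes(batch_shape, BYTES_FP32)
fp16_bytes = tensor_bytes(batch_shape, BYTES_FP16)
print(fp32_bytes, fp16_bytes)  # 19267584 9633792
```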

The following figure shows the GPU usage from different experiments -

As you can see, the different flavors of mixed precision training tend to follow similar patterns in terms of GPU usage. When I did not use mixed precision, the GPU was actually utilized better. This might prompt developers to dig further and figure out ways to improve GPU utilization. But there is always a trade-off between GPU utilization and model performance, and it changes from project to project.

Following are some comprehensive graphs that summarize the key metrics from the conducted experiments -

The legends of the graphs show the different runs (one run = one experiment). It’s very important to turn on the XLA compiler when using mixed precision training: as the figure above shows, mixed precision training without the XLA compiler makes model training more time-consuming.

As a machine learning practitioner, you will be expected to care about aspects like this more often than not.

For each of the different runs, we also get a dedicated logs page that shows the local training log -

You can also initialize W&B (wandb.init()) just before you start the training process, and it will appear under the Runtime tab.

Thank you!

Thank you for sticking with me till the end. I hope you will benefit from this article and incorporate mixed precision training into your own experiments. Don’t forget to let me know the results!

I would like to thank Lavanya from the W&B team, who carefully reviewed the code and the article itself. Her feedback was immensely helpful.
