State-of-the-art machine learning models are often bulky, which makes them inefficient to deploy in resource-constrained environments such as mobile phones, Raspberry Pis, and microcontrollers. You might think you can get around this problem by hosting your model in the cloud and serving results through an API, but consider environments where internet bandwidth is not always high, or where data must not leave a particular device.
We need a set of tools that make the transition to on-device machine learning seamless. In this report, I will show you how TensorFlow Lite (TF Lite) can really shine in situations like this. We'll cover model optimization strategies and quantization techniques supported by TensorFlow.
Thanks to Arun, Khanh, and Pulkit (Google) for sharing incredibly useful tips for this report.
Generally, our machine learning models operate in float32 precision. All the model parameters are stored in this format, which often leads to heavier models. The heaviness of a model directly affects the speed at which it makes predictions. So it might naturally occur to you that if we could reduce the precision in which our models operate, we could cut down on prediction times. That is what quantization does: it represents the parameters of a model in lower-precision formats such as float16 or int8.
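To make the idea concrete, here is a minimal sketch of one common way to map float values onto 8-bit integers, using a scale and a zero point. It is an illustration in plain Python under stated assumptions, not TF Lite's actual implementation; the function names and the sample weights are made up for this example.

```python
def quantize_int8(values):
    """Map floats to int8 codes with an affine (scale + zero-point) scheme."""
    # The representable range is stretched to include 0.0 so that zero
    # maps exactly onto an integer code.
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / 255.0 or 1.0  # fall back to 1.0 for all-zero input
    zero_point = round(-128 - lo / scale)
    codes = [max(-128, min(127, round(v / scale) + zero_point)) for v in values]
    return codes, scale, zero_point


def dequantize_int8(codes, scale, zero_point):
    """Recover approximate floats from int8 codes."""
    return [(c - zero_point) * scale for c in codes]


# Hypothetical weights, chosen only for illustration
weights = [-0.82, -0.11, 0.0, 0.37, 1.45]
codes, scale, zero_point = quantize_int8(weights)
restored = dequantize_int8(codes, scale, zero_point)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print("int8 codes:", codes)
print("largest round-trip error:", max_error)
```

Each value now needs one byte instead of four, and the price is a rounding error of at most half a quantization step (`scale / 2`) per value.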
Quantization can be applied to a model in two flavors: post-training quantization and quantization-aware training (QAT).
We will see both these flavors in this report. Let's get started!
All of the experiments in this report were performed on Colab. I used the flowers dataset for the experiments and fine-tuned a pre-trained MobileNetV2 network to start off with. Here's the code that defines the network architecture:
    from tensorflow.keras.applications import MobileNetV2
    from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
    from tensorflow.keras.models import Model

    # Load the MobileNetV2 model but exclude the classification layers
    EXTRACTOR = MobileNetV2(weights="imagenet", include_top=False,
                            input_shape=(224, 224, 3))

    # We will set it to both True and False
    EXTRACTOR.trainable = True

    # Construct the head of the model that will be placed on top of
    # the base model
    class_head = EXTRACTOR.output
    class_head = GlobalAveragePooling2D()(class_head)
    class_head = Dense(512, activation="relu")(class_head)
    class_head = Dropout(0.5)(class_head)
    class_head = Dense(5, activation="softmax")(class_head)

    # Create the new model
    classifier = Model(inputs=EXTRACTOR.input, outputs=class_head)

    # Compile the model
    classifier.compile(loss="sparse_categorical_crossentropy",
                       optimizer="adam",
                       metrics=["accuracy"])
The networks were trained for 10 epochs with a batch size of 32.
After you have trained a model in tf.keras, quantizing it is just a matter of a few lines of code:
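    import tensorflow as tf

    # non_qat_flower_model is the trained tf.keras model from the previous step
    converter = tf.lite.TFLiteConverter.from_keras_model(non_qat_flower_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    quantized_tflite_model = converter.convert()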
You first load your model into a TFLiteConverter, then specify an optimization policy, and finally ask TF Lite to convert your model with that policy. Serializing the converted TF Lite model is straightforward:
    # The context manager closes the file even if the write fails
    with open("normal_flower_model.tflite", "wb") as f:
        f.write(quantized_tflite_model)
This form of quantization is also referred to as post-training dynamic range quantization. It quantizes the weights of your model to 8 bits of precision. Here you can find more details about this and other post-training quantization schemes.
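Once the original and the converted model are both on disk, a quick way to see what quantization bought you is to compare file sizes. The sketch below uses only the standard library. The helper names are hypothetical, and the two files are zero-filled stand-ins so the snippet runs anywhere; in practice you would pass the paths of your saved Keras model and your `.tflite` file.

```python
import os
import tempfile


def size_in_mb(path):
    """Return the size of a file on disk in megabytes (1 MB = 1e6 bytes)."""
    return os.path.getsize(path) / 1e6


def compression_ratio(original_path, converted_path):
    """How many times smaller the converted file is than the original."""
    return os.path.getsize(original_path) / os.path.getsize(converted_path)


# Stand-in files: not real models, just byte counts chosen for illustration.
with tempfile.TemporaryDirectory() as tmp:
    original = os.path.join(tmp, "original.bin")
    converted = os.path.join(tmp, "converted.tflite")
    with open(original, "wb") as f:
        f.write(b"\0" * 4_000_000)  # e.g. one million 4-byte float32 values
    with open(converted, "wb") as f:
        f.write(b"\0" * 1_000_000)  # the same count stored at 1 byte each
    original_mb = size_in_mb(original)
    converted_mb = size_in_mb(converted)
    ratio = compression_ratio(original, converted)

print(f"{original_mb:.1f} MB -> {converted_mb:.1f} MB ({ratio:.1f}x smaller)")
```

Real `.tflite` files also carry graph structure and metadata, so the ratio you measure will usually be somewhat below the ideal 4x for 8-bit weights.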
TF Lite allows us to specify a number of different configurations when converting our models. We saw one such configuration in the code above, where we specified the optimization policy.
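Apart from tf.lite.Optimize.DEFAULT, there are two other policies available: tf.lite.Optimize.OPTIMIZE_FOR_SIZE and tf.lite.Optimize.OPTIMIZE_FOR_LATENCY. From the names, you can see that, based on the choice of policy, TF Lite will try to optimize the models accordingly. Note that recent TensorFlow releases document both of these as deprecated and equivalent to DEFAULT.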
We can also specify other conversion options on the converter. Learn more about the TFLiteConverter class here. It's important to note that these configuration options let us balance the trade-off between a model's prediction speed and its accuracy. Here, you can find a number of trade-offs with respect to the different post-training quantization schemes available in TF Lite.
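To judge that trade-off for your own models, it helps to measure both sides of it in the same way for every variant. The following is a rough harness sketch, not part of TF Lite: `predict_fn` is assumed to be any callable that takes one sample and returns a class index, and `argmax_predictor` is a hypothetical stand-in for a real model so the snippet runs on its own.

```python
import time


def average_latency_ms(predict_fn, sample, warmup=3, runs=20):
    """Average wall-clock time of predict_fn(sample), in milliseconds."""
    for _ in range(warmup):  # warm-up calls are excluded from the timing
        predict_fn(sample)
    start = time.perf_counter()
    for _ in range(runs):
        predict_fn(sample)
    return (time.perf_counter() - start) * 1000.0 / runs


def accuracy(predict_fn, samples, labels):
    """Fraction of samples whose predicted class matches the label."""
    correct = sum(predict_fn(s) == y for s, y in zip(samples, labels))
    return correct / len(labels)


def argmax_predictor(scores):
    """Hypothetical stand-in for a model: returns the index of the top score."""
    return max(range(len(scores)), key=scores.__getitem__)


# Made-up scores and labels, used only to exercise the harness
samples = [[0.1, 0.7, 0.2], [0.9, 0.05, 0.05], [0.3, 0.3, 0.4]]
labels = [1, 0, 1]

latency_ms = average_latency_ms(argmax_predictor, samples[0])
acc = accuracy(argmax_predictor, samples, labels)
print(f"latency: {latency_ms:.4f} ms, accuracy: {acc:.2f}")
```

Running the same two functions against the original model and each quantized variant gives numbers that can be compared directly. Latency measured on a desktop or in Colab may differ a lot from what the target device shows, so treat it as a relative indicator.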
Below we can see some useful statistics on this converted model.
A good first approach here is to train your model so that it learns to compensate for the information loss that quantization might introduce. Quantization-aware training (QAT) lets us do just that. To train our network in a quantization-aware manner, we just add the following lines of code:
    import tensorflow_model_optimization as tfmot

    qat_model = tfmot.quantization.keras.quantize_model(your_keras_model)
Now, after compiling it, you can train qat_model in the same way you would train a tf.keras model. Here you can find a comprehensive coverage of QAT.
Below, we can see that this quantization-aware model does slightly better than our previous model.
In terms of model size, the QAT model is similar to the non-QAT model:
In the following table, we see that the quantized version of the QAT model indeed performs better than the previous model.
To keep our converted models in float precision, we just need to drop this line: converter.optimizations = [tf.lite.Optimize.DEFAULT]. Note that float16 quantization is also supported in TensorFlow Lite. This scheme is particularly helpful if you want to take advantage of GPU delegates. In the table below, we can see the size and accuracy of the models quantized using this scheme.
There are other post-training quantization techniques available as well, such as full integer quantization, float16 quantization, etc. This is where you can learn more about them. Keep in mind that the full integer quantization scheme might not always be compatible with a QAT model.
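To get a feel for what float16 storage trades away, the sketch below round-trips a few values through IEEE 754 half precision using Python's `struct` module (format code `e`). It only illustrates the number format. It does not call TF Lite, and the sample weights are made up.

```python
import struct


def to_float16_and_back(value):
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


# Hypothetical weights, chosen only for illustration
weights = [0.1, -1.2345678, 3.14159265, 1024.5]
as_float16 = [to_float16_and_back(w) for w in weights]

# Storage needed for the same values in each format
bytes_float32 = len(struct.pack(f"<{len(weights)}f", *weights))
bytes_float16 = len(struct.pack(f"<{len(weights)}e", *weights))

max_rel_error = max(abs(w - h) / abs(w) for w, h in zip(weights, as_float16))

print("float32 bytes:", bytes_float32, "float16 bytes:", bytes_float16)
print("largest relative error:", max_rel_error)
```

Storage is halved, and for values inside float16's normal range the relative rounding error stays below about 0.05% (2^-11), which is one reason float16 is often a gentle first step before trying 8-bit schemes.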
There are a number of SoTA pre-trained TF Lite models hosted for the developers to use for their applications and they can be found here:
For mobile developers who are looking to integrate machine learning in their applications, there are a number of example applications in TF Lite worth checking out. TensorFlow Lite also provides tooling for embedded systems and microcontrollers and you can learn more about it from here.
If you'd like to reproduce the results of this analysis, you can –