INT8 MobileNetV2 example
This is an example of an 8-bit integer (INT8) quantized TensorFlow Keras model using post-training quantization. In other words, the model is trained with normal floating-point arithmetic, but can run in INT8 mode at inference time. Post-training quantization requires a representative dataset to calibrate the quantization parameters. For more information, see the TensorFlow post-training quantization documentation.
First we define our model and a representative data-set as follows:
import tensorflow as tf

IMAGE_SHAPE = (224, 224, 3)  # example 224x224 RGB image


def dataset_example(num_samples: int = 100):
    """Placeholder for a representative data-set. For best quantization
    performance, replace this with a few examples from your own data-set, the
    more the better. This should include any pre-processing needed."""
    for _ in range(num_samples):
        yield [tf.random.uniform(shape=(1, *IMAGE_SHAPE), minval=-1, maxval=1)]
def model() -> tf.keras.Model:
    """Example instance of a MobileNetV2 network from the Keras applications
    package. More information about the arguments and other similar networks can
    be found in the Keras docs: https://keras.io/api/applications/mobilenet/."""
    return tf.keras.applications.MobileNetV2(
        input_shape=IMAGE_SHAPE,
        alpha=0.35,
        include_top=True,
        weights="imagenet",
        input_tensor=None,
        pooling=None,
        classes=1000,
        classifier_activation=None,
    )
After we have trained the above MobileNetV2 example model (not shown in this example), we can convert it to INT8 quantized TFLite format as follows:
from pathlib import Path
from typing import Callable

import tensorflow as tf


def convert_model(model: tf.keras.Model, dataset_gen: Callable) -> bytes:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.experimental_new_converter = True
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = dataset_gen
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    return converter.convert()


MobileNetV2_model = model()
# (now train the model or load some weights)
int8_MobileNetV2_model = convert_model(MobileNetV2_model, dataset_example)
Path("MobileNetV2.tflite").write_bytes(int8_MobileNetV2_model)
The resulting MobileNetV2.tflite is now ready to be used with the inference engine.
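Because the converted model uses INT8 inputs and outputs, callers must quantize float inputs and de-quantize outputs using the scale and zero-point that the converter stored in the model. The sketch below shows this round trip with the standard tf.lite.Interpreter; it converts a small stand-in model inline so it is self-contained, but with the example above you would pass model_path="MobileNetV2.tflite" instead of model_content:

```python
import numpy as np
import tensorflow as tf

# Tiny stand-in model so this sketch runs on its own; with the example
# above you would load "MobileNetV2.tflite" via model_path instead.
tiny = tf.keras.Sequential(
    [tf.keras.layers.Input(shape=(4,)), tf.keras.layers.Dense(10)]
)
converter = tf.lite.TFLiteConverter.from_keras_model(tiny)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = lambda: (
    [np.random.uniform(-1, 1, size=(1, 4)).astype(np.float32)] for _ in range(10)
)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_bytes = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Quantize a float input with the scale/zero-point stored in the model.
scale, zero_point = input_details["quantization"]
x = np.random.uniform(-1, 1, size=(1, 4)).astype(np.float32)
x_int8 = np.clip(np.round(x / scale + zero_point), -128, 127).astype(np.int8)

interpreter.set_tensor(input_details["index"], x_int8)
interpreter.invoke()
y_int8 = interpreter.get_tensor(output_details["index"])

# De-quantize the INT8 output back to floats.
out_scale, out_zero_point = output_details["quantization"]
y = (y_int8.astype(np.float32) - out_zero_point) * out_scale
```

The quantization parameters are per-model, so always read them from the interpreter's input and output details rather than hard-coding them.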