
INT8 LeNet example

This is an example of an 8-bit integer (INT8) quantized TensorFlow Keras model using post-training quantization. In other words, the model is trained with ordinary floating-point training but runs in INT8 mode at inference time. Post-training quantization requires representative example data, which the converter uses to calibrate the value ranges of the model's activations. For more information, see the TensorFlow Lite documentation on post-training quantization.
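
To see what "calibrating a value range" means, recall that TensorFlow Lite's INT8 scheme represents a real value as real = scale × (q − zero_point), where q is an int8. The pure-Python sketch below derives scale and zero point from an observed [min, max] range and round-trips a few values. It is a simplified asymmetric min/max scheme for illustration only: the helper names are hypothetical, and the actual converter differs in details (for example, weights use symmetric, per-channel scales with a zero point of 0).

```python
INT8_MIN, INT8_MAX = -128, 127

def quant_params(x_min: float, x_max: float) -> tuple[float, int]:
    """Asymmetric INT8 scale and zero point for an observed range [x_min, x_max]."""
    # Widen the range to include 0.0 so that zero is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    if x_max == x_min:  # degenerate all-zero range
        return 1.0, 0
    scale = (x_max - x_min) / (INT8_MAX - INT8_MIN)
    zero_point = round(INT8_MIN - x_min / scale)
    return scale, max(INT8_MIN, min(INT8_MAX, zero_point))

def quantize(x: float, scale: float, zero_point: int) -> int:
    """real -> int8: q = round(x / scale) + zero_point, clamped to the int8 range."""
    q = round(x / scale) + zero_point
    return max(INT8_MIN, min(INT8_MAX, q))

def dequantize(q: int, scale: float, zero_point: int) -> float:
    """int8 -> real: x = scale * (q - zero_point)."""
    return scale * (q - zero_point)

# Range observed on the placeholder data below, which is uniform in [-1, 1).
scale, zero_point = quant_params(-1.0, 1.0)
for x in (-1.0, -0.5, 0.0, 0.3, 1.0):
    q = quantize(x, scale, zero_point)
    print(f"{x:+.2f} -> {q:+4d} -> {dequantize(q, scale, zero_point):+.4f}")
```

Each value comes back within half a quantization step (scale / 2 ≈ 0.004), which is why the calibration range should come from data that resembles what the model sees in practice.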

First we define our model and a representative dataset as follows:

import tensorflow as tf

IMAGE_SHAPE = (28, 28, 1)  # Example 28x28 grayscale image (e.g. MNIST)

def dataset_example(num_samples: int = 100):
    """Placeholder for a representative dataset. For best quantization accuracy,
    replace this with samples from your own dataset; more samples generally give
    a better calibration. Apply the same pre-processing as used for training."""
    for _ in range(num_samples):
        yield [tf.random.uniform(shape=(1, *IMAGE_SHAPE), minval=-1, maxval=1)]

def model() -> tf.keras.Model:
    """Example convolutional neural network (CNN) based on the original LeNet-5
    architecture from Yann LeCun, see https://en.wikipedia.org/wiki/LeNet."""
    input_layer = tf.keras.layers.Input(shape=IMAGE_SHAPE, batch_size=1)
    x = input_layer
    x = tf.keras.layers.Conv2D(6, kernel_size=(3, 3), activation="relu")(x)
    x = tf.keras.layers.AveragePooling2D()(x)
    x = tf.keras.layers.Conv2D(16, kernel_size=(3, 3), activation="relu")(x)
    x = tf.keras.layers.AveragePooling2D()(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(units=120, activation="relu")(x)
    x = tf.keras.layers.Dense(units=84, activation="relu")(x)
    x = tf.keras.layers.Dense(units=10, activation="softmax")(x)
    return tf.keras.Model(input_layer, x, name="LeNet")
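
The layer sizes of this model can be checked by hand. The sketch below assumes the Keras defaults used above (padding="valid" and stride 1 for Conv2D; a 2x2 pool with strides equal to the pool size for AveragePooling2D), traces the spatial size through the network and counts the trainable parameters. The helper names are hypothetical; the totals should agree with what LeNet_model.summary() reports.

```python
def conv_out(size: int, kernel: int) -> int:
    """Spatial size after a stride-1 'valid' convolution (Keras Conv2D default)."""
    return size - kernel + 1

def pool_out(size: int, pool: int = 2) -> int:
    """Spatial size after 'valid' pooling with strides == pool size (Keras default)."""
    return (size - pool) // pool + 1

def lenet_shapes_and_params(size: int = 28, channels: int = 1) -> tuple[int, int, int]:
    params = 0
    # Conv2D(6, 3x3): k*k*c_in weights per filter plus one bias per filter
    params += (3 * 3 * channels + 1) * 6
    size = pool_out(conv_out(size, 3))  # 28 -> 26 -> 13
    # Conv2D(16, 3x3) on 6 input channels
    params += (3 * 3 * 6 + 1) * 16
    size = pool_out(conv_out(size, 3))  # 13 -> 11 -> 5
    flat = size * size * 16             # Flatten: 5 * 5 * 16
    units_in = flat
    for units in (120, 84, 10):         # Dense layers: weights plus biases
        params += (units_in + 1) * units
        units_in = units
    return size, flat, params

print(*lenet_shapes_and_params())  # 5 400 60074
```

With roughly 60k parameters, storing the weights as int8 instead of float32 shrinks them to about a quarter of their size.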

After training the LeNet model above (not shown in this example), we can convert it to the INT8-quantized TFLite format as follows:

from pathlib import Path
from typing import Callable
import tensorflow as tf

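def convert_model(model: tf.keras.Model, dataset_gen: Callable) -> bytes:
    """Convert a Keras model to a fully INT8-quantized TFLite flatbuffer."""
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.experimental_new_converter = True  # MLIR-based converter (default in recent TF)
    # Enable quantization; the representative dataset calibrates activation ranges
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = dataset_gen
    # Allow only INT8 builtin ops: conversion fails if an op cannot be quantized
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    # Make the model's input and output tensors INT8 too (instead of float32)
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    return converter.convert()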

LeNet_model = model()
# (now train the model or load some weights)
int8_LeNet_model = convert_model(LeNet_model, dataset_example)
Path("LeNet.tflite").write_bytes(int8_LeNet_model)
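
The random placeholder data only illustrates the interface; for a useful calibration, dataset_gen should yield real, pre-processed inputs. The sketch below builds such a zero-argument generator function from an image array. It assumes uint8 images of shape (N, 28, 28), such as MNIST, and rescales them to [-1, 1] to match the range of dataset_example; both the helper name and that pre-processing are assumptions and must be replaced by whatever your training pipeline actually does.

```python
from typing import Callable, Iterator, List

import numpy as np

def make_representative_dataset(
    images: np.ndarray, num_samples: int = 100
) -> Callable[[], Iterator[List[np.ndarray]]]:
    """Hypothetical helper: wrap real images in a generator for the converter.

    Assumes uint8 images of shape (N, 28, 28); the [-1, 1] rescaling below is an
    assumption and must match the pre-processing used during training."""
    def gen() -> Iterator[List[np.ndarray]]:
        for image in images[:num_samples]:
            x = image.astype(np.float32) / 127.5 - 1.0  # [0, 255] -> [-1, 1]
            x = x.reshape(1, 28, 28, 1)                 # add batch and channel dims
            yield [x]                                   # one input tensor per sample
    return gen

# Usage (x_train is a hypothetical uint8 array of training images):
# int8_LeNet_model = convert_model(LeNet_model, make_representative_dataset(x_train))
```

Returning a closure keeps the same calling convention as dataset_example: the converter receives a callable that produces a fresh generator each time it is invoked.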

The resulting LeNet.tflite is now ready to be used with the inference engine.