INT8 LSTM example
This is an example of an 8-bit integer (INT8) quantized TensorFlow Keras model using post-training quantization. In other words, the model is trained with normal floating-point training and is converted afterwards so that it runs in INT8 mode at inference time. Post-training quantization requires representative example data, which the converter uses to calibrate the value ranges of the model's activations. For more information, see the TensorFlow documentation about this subject.
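To make the role of the representative data more concrete, the following is a minimal sketch of affine INT8 quantization: a scale and zero point are derived from the observed value range, and floats are then mapped to 8-bit integers and back. The function names (`calibrate`, `quantize`, `dequantize`) are hypothetical, and the min/max rule is a simplification assumed here for illustration; the TFLite converter's actual calibration is more involved.

```python
from typing import Iterable, Tuple

INT8_MIN, INT8_MAX = -128, 127


def calibrate(values: Iterable[float]) -> Tuple[float, int]:
    """Derive (scale, zero_point) from the min/max of observed values.

    Simplified illustration only, not the converter's exact algorithm.
    """
    values = list(values)
    # Extend the range to include 0.0 so that zero is represented exactly.
    lo = min(min(values), 0.0)
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (INT8_MAX - INT8_MIN)
    if scale == 0.0:
        return 1.0, 0  # all observed values are zero; any scale works
    zero_point = round(INT8_MIN - lo / scale)
    return scale, zero_point


def quantize(x: float, scale: float, zero_point: int) -> int:
    """Map a float to the nearest INT8 value, clamping to the INT8 range."""
    q = round(x / scale) + zero_point
    return max(INT8_MIN, min(INT8_MAX, q))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Map an INT8 value back to an approximate float."""
    return (q - zero_point) * scale


# Calibrate on a few observed values, then round-trip one of them.
observed = [-1.0, -0.5, 0.0, 0.25, 1.0]
scale, zero_point = calibrate(observed)
approx = dequantize(quantize(0.25, scale, zero_point), scale, zero_point)
```

Values outside the calibrated range are clamped, which is why calibration data that covers the ranges seen in real inputs matters for accuracy.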
First we define our model and a representative data-set as follows:
    import tensorflow as tf

    NUM_SAMPLES = 10  # e.g. 10 letters in a word or 10 timestamps
    SAMPLE_SIZE = 64  # size of a single sample, e.g. an embedding of size 64


    def dataset_example(num_samples: int = 100):
        """Placeholder for a representative data-set. For best quantization
        performance, replace this with a few examples from your own data-set,
        the more the better. This should include any pre-processing needed."""
        for _ in range(num_samples):
            shape = (1, NUM_SAMPLES, SAMPLE_SIZE)
            yield [tf.random.uniform(shape, minval=-1, maxval=1)]


    def model() -> tf.keras.Model:
        """Example of a simple single-directional LSTM. Embedding layers are not
        supported and will have to be added as pre-processing steps manually.
        See the official TFLite documentation for more information and
        constraints: https://www.tensorflow.org/lite/models/convert/rnn.
        See in-line comments below for recommendations about
        tf.keras.layers.LSTM()."""
        in_layer = tf.keras.layers.Input((NUM_SAMPLES, SAMPLE_SIZE), batch_size=1)
        x = in_layer

        # Recommended settings for tf.keras.layers.LSTM:
        # 1) Set unroll=False (the default) for much better speed, RAM, and ROM
        # 2) Set time_major=False (the default) for better speed, RAM, and ROM
        # 3) Set activation='tanh' (the default) for compatibility
        # 4) Set recurrent_activation='sigmoid' (the default) for compatibility
        x = tf.keras.layers.LSTM(units=16, return_sequences=True)(x)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(8)(x)
        return tf.keras.Model(in_layer, x, name="LSTM")
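The placeholder generator above yields random data. As a sketch of how it might be replaced with real samples, the generator below reads from an in-memory NumPy array instead. The name `dataset_from_array` and the array `my_data` are hypothetical, and the array is assumed to be already pre-processed and shaped `(N, NUM_SAMPLES, SAMPLE_SIZE)`.

```python
import numpy as np

NUM_SAMPLES = 10  # time steps per sequence, matching the model definition
SAMPLE_SIZE = 64  # features per time step, matching the model definition


def dataset_from_array(samples: np.ndarray, num_samples: int = 100):
    """Yield calibration inputs taken from an array of real samples.

    `samples` is assumed to have shape (N, NUM_SAMPLES, SAMPLE_SIZE).
    """
    if samples.shape[1:] != (NUM_SAMPLES, SAMPLE_SIZE):
        raise ValueError(f"unexpected sample shape {samples.shape[1:]}")
    for sample in samples[:num_samples]:
        # Add the batch dimension of 1 used by the model and cast to float32.
        yield [sample[np.newaxis, ...].astype(np.float32)]


# Hypothetical stand-in for real, pre-processed data.
my_data = np.random.uniform(-1.0, 1.0, size=(200, NUM_SAMPLES, SAMPLE_SIZE))
first_batch = next(dataset_from_array(my_data))
```

Because the converter expects a callable that returns a generator, the array can be bound with a lambda when calling the conversion function, for example `convert_model(LSTM_model, lambda: dataset_from_array(my_data))`.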
After we have trained the above LSTM example model (not shown in this example), we can convert it to INT8 quantized TFLite format as follows:
    from pathlib import Path
    from typing import Callable

    import tensorflow as tf


    def convert_model(model: tf.keras.Model, dataset_gen: Callable) -> bytes:
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        converter.experimental_new_converter = True
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = dataset_gen
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        return converter.convert()


    LSTM_model = model()
    # (now train the model or load some weights)
    int8_LSTM_model = convert_model(LSTM_model, dataset_example)
    Path("LSTM.tflite").write_bytes(int8_LSTM_model)
The resulting LSTM.tflite is now ready to be used with the inference engine.
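Before deploying, it can be useful to sanity-check the converted file on a desktop machine. The sketch below uses TensorFlow's own `tf.lite.Interpreter` for this, which is an assumption made here and not necessarily the inference engine this page refers to. Since the model was converted with INT8 input and output types, float inputs are quantized with the scale and zero point reported by the interpreter, and outputs are dequantized the same way. The helper names are hypothetical.

```python
import numpy as np


def quantize_input(x: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """Map float values to int8 using the input tensor's quantization parameters."""
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)


def dequantize_output(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """Map int8 model output back to approximate float values."""
    return (q.astype(np.float32) - zero_point) * scale


def run_int8_model(model_path: str, x: np.ndarray) -> np.ndarray:
    """Run one float input of shape (1, 10, 64) through the INT8 model."""
    # Imported here so the helpers above can be used without TensorFlow.
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]

    # Each tensor reports its (scale, zero_point) pair under "quantization".
    in_scale, in_zero = inp["quantization"]
    out_scale, out_zero = out["quantization"]

    interpreter.set_tensor(inp["index"], quantize_input(x, in_scale, in_zero))
    interpreter.invoke()
    return dequantize_output(interpreter.get_tensor(out["index"]), out_scale, out_zero)
```

A call such as `run_int8_model("LSTM.tflite", sample)` with a float32 array `sample` of shape `(1, 10, 64)` returns the dequantized output, which can be compared against the floating-point Keras model's prediction for the same input to estimate the quantization error.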