I'm currently trying to use Keras' Quantization Aware Training, specifically because I need to do 8-bit inference on a low-precision device. For this reason, I need the batch norm folded into the convolution, to avoid keeping the 32-bit moving mean and variance around at inference time. The sample code I'm starting with is the following (tf1.15, tensorflow-model-optimization 0.6.0):
import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(224, 224, 3)),
    tf.keras.layers.Conv2D(filters=3, kernel_size=(3, 3)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1000)
])
quantize_model = tfmot.quantization.keras.quantize_model
# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)
# `quantize_model` requires a recompile.
# `smooth` is the label-smoothing factor, defined earlier in my script.
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=smooth),
                      metrics=['accuracy'])
q_aware_model.summary()
The documentation states that a 'Conv2D + BatchNorm + ReLU' block should have the BatchNorm folded into the Conv2D, but that isn't what I see in the .h5 file this produces.
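For completeness, this is roughly how I inspected the result. It's a minimal sketch that assumes the QAT model was saved with q_aware_model.save('q_aware_model.h5') (the filename is just a placeholder for wherever the model was written):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# quantize_scope() lets Keras deserialize the QuantizeWrapper layers stored in the .h5.
with tfmot.quantization.keras.quantize_scope():
    loaded = tf.keras.models.load_model('q_aware_model.h5')

# List the layer classes: the BatchNormalization layer still shows up as its own
# (wrapped) layer instead of being folded into the preceding Conv2D.
for layer in loaded.layers:
    print(layer.name, layer.__class__.__name__)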