I am trying to find out exactly how the BatchNormalization layer behaves in TensorFlow. I came up with the following piece of code, which to the best of my knowledge should be a perfectly valid Keras model; however, the moving mean and variance of BatchNormalization don't appear to be updated.
From the docs (https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization):
in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).
I expect the model to return a different value with each subsequent predict call. What I see, however, are the exact same values returned 10 times. Can anyone explain to me why the BatchNormalization layer does not update its internal values?
import tensorflow as tf
import numpy as np

if __name__ == '__main__':
    np.random.seed(1)
    x = np.random.randn(3, 5) * 5 + 0.3

    bn = tf.keras.layers.BatchNormalization(trainable=False, epsilon=1e-9)
    z = input = tf.keras.layers.Input([5])
    z = bn(z)

    model = tf.keras.Model(inputs=input, outputs=z)

    for i in range(10):
        print(x)
        print(model.predict(x))
        print()

I use TensorFlow 2.1.0.
Okay, I found the mistake in my assumptions. The moving averages are updated during training, not during inference as I had thought. This makes perfect sense, as updating the moving averages during inference would likely result in an unstable production model. For example, a long sequence of highly pathological input samples (e.g. samples whose generating distribution differs drastically from the one on which the network was trained) could bias the network and result in worse performance on valid input samples.
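To make the update rule concrete, here is a minimal pure-Python sketch. It is not TensorFlow's implementation: RunningMean is a hypothetical toy class that tracks only a scalar mean, and I am assuming the usual exponential moving average with the Keras default momentum of 0.99.

```python
class RunningMean:
    """Toy stand-in (hypothetical) for the moving mean of a batch-norm layer."""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.moving_mean = 0.0  # Keras also initializes the moving mean to zero

    def __call__(self, batch, training=False):
        if training:
            # Training: center with the batch mean and update the moving mean.
            batch_mean = sum(batch) / len(batch)
            self.moving_mean = (self.momentum * self.moving_mean
                                + (1.0 - self.momentum) * batch_mean)
            return [v - batch_mean for v in batch]
        # Inference: only read the moving mean, never write to it.
        return [v - self.moving_mean for v in batch]


layer = RunningMean()
batch = [1.0, 2.0, 3.0]

layer(batch, training=False)
after_inference = layer.moving_mean  # unchanged by the inference call

layer(batch, training=True)
after_training = layer.moving_mean   # moved 1% of the way toward the batch mean

print(after_inference, after_training)
```

With a momentum this close to 1, a single training step barely moves the statistic, which is why the moving averages need many batches to converge.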
The trainable parameter is useful when you're fine-tuning a pretrained model and want to freeze some of the layers of the network even during training. It is not needed for inference, because when you call model.predict(x) (or even model(x) or model(x, training=False)), the layer automatically uses the moving averages instead of the batch statistics.
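Here is a minimal sketch of how this can be checked, assuming TensorFlow 2.x in eager mode. The variable names are my own, and I use the default momentum; the layer's moving_mean weight is compared before and after calls in inference and training mode, and a second layer with trainable=False is called with training=True to show that it stays frozen.

```python
import numpy as np
import tensorflow as tf

np.random.seed(1)
x = (np.random.randn(3, 5) * 5 + 0.3).astype("float32")

# A regular (trainable) layer wrapped in a model, as in the question.
bn = tf.keras.layers.BatchNormalization(epsilon=1e-9)
inputs = tf.keras.layers.Input([5])
model = tf.keras.Model(inputs=inputs, outputs=bn(inputs))

initial_mean = bn.moving_mean.numpy().copy()

# Inference mode: the moving statistics are read but not updated,
# so repeated calls return identical outputs.
out_a = model(x, training=False).numpy()
out_b = model(x, training=False).numpy()
mean_after_inference = bn.moving_mean.numpy().copy()

# Training mode: batch statistics are used and the moving averages move.
model(x, training=True)
mean_after_training = bn.moving_mean.numpy().copy()

# A frozen layer runs in inference mode even when training=True is passed.
frozen_bn = tf.keras.layers.BatchNormalization(trainable=False, epsilon=1e-9)
frozen_bn(x, training=True)
frozen_mean = frozen_bn.moving_mean.numpy().copy()

print("same outputs at inference:", np.allclose(out_a, out_b))
print("moving mean changed by inference:",
      not np.allclose(initial_mean, mean_after_inference))
print("moving mean changed by training:",
      not np.allclose(initial_mean, mean_after_training))
print("frozen layer moving mean changed:",
      not np.allclose(frozen_mean, np.zeros_like(frozen_mean)))
```

In practice the training-mode updates happen inside model.fit(...); the explicit training=True call above only stands in for a single training step.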
In the end, I found that the documentation (https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization) describes this behavior.
Hopefully this will help someone with a similar misunderstanding in the future.