I am trying to make sure that I am incorporating batch normalization layers into a model correctly.
The code snippet below illustrates what I am doing.
- Is this an appropriate use of batch normalization?
- At inference time, how can I access the moving averages in each batch normalization layer to make sure they are being loaded?
    import os
    import tensorflow.compat.v1 as tf
    from model import Model

    tf.disable_v2_behavior()  # graph-mode code: tf.layers, sessions, collections

    # Sample batch normalization layer in the Model class
    x_preBN = ...
    x_postBN = tf.layers.batch_normalization(inputs=x_preBN,
                                             center=True,
                                             scale=True,
                                             momentum=0.9,
                                             training=(self.mode == 'train'))

    # During training:
    model = Model(mode='train')
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for it in range(max_iterations):
            # Training step + update of the BN moving statistics
            sess.run([train_step, extra_update_ops], feed_dict=...)

            # Store checkpoint
            if it % num_checkpoint_steps == 0:
                saver.save(sess,
                           os.path.join(model_dir, 'checkpoint'),
                           global_step=it)

    # During inference:
    model = Model(mode='eval')
    saver = tf.train.Saver()

    with tf.Session() as sess:
        saver.restore(sess, os.path.join(model_dir, 'checkpoint-???'))
        acc = sess.run(model.accuracy, feed_dict=...)
Once the model has been instantiated, the list of all global variables in the graph can be obtained with tf.global_variables():
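    # Every variable in the graph, including the batch normalization variables
    for v in tf.global_variables():
        print(v.name, v.shape)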
For a specific layer, the batch normalization variables are gamma, beta, moving_mean, and moving_variance: gamma and beta are trainable, whereas the moving statistics are not (hence the need to run extra_update_ops during training).
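As a sketch, they can be picked out by name; the prefix batch_normalization below is the default scope created by tf.layers.batch_normalization and will differ if the layer or the enclosing model uses its own name:

    # Collect the variables of one BN layer by (assumed) scope name
    bn_vars = [v for v in tf.global_variables()
               if 'batch_normalization' in v.name]
    for v in bn_vars:
        print(v.name)
    # With the default scope this typically prints:
    #   batch_normalization/gamma:0            (trainable)
    #   batch_normalization/beta:0             (trainable)
    #   batch_normalization/moving_mean:0      (not trainable)
    #   batch_normalization/moving_variance:0  (not trainable)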
They can be accessed as usual, for example by evaluating them in the restored session to check that the checkpointed moving averages were actually loaded:
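A minimal sketch, assuming it runs inside the inference session after saver.restore:

    # Fetch the moving statistics of the restored model; values different from
    # the initializers (0 for the mean, 1 for the variance) indicate that the
    # checkpointed moving averages were loaded.
    moving_stats = [v for v in tf.global_variables()
                    if 'moving_mean' in v.name or 'moving_variance' in v.name]
    for v in moving_stats:
        print(v.name, sess.run(v))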