I was trying to run a model on multiple GPUs with TensorFlow's MirroredStrategy.
I used a custom loss function like this:
import tensorflow as tf

def mae(y_true, y_pred):
    # y_true, y_pred shape = (B, L)
    loss = tf.keras.metrics.mean_absolute_error(y_true, y_pred)
    # loss shape = (B,)
    return loss

class custom_loss(tf.keras.losses.Loss):
    def __init__(self, BATCH_SIZE=1, **kwargs):
        super(custom_loss, self).__init__(**kwargs)
        self.BATCH_SIZE = BATCH_SIZE  # global batch size

    def call(self, y_true, y_pred):
        # y_true, y_pred shape = (B, L, 1)
        loss = mae(tf.squeeze(y_true, [-1]), tf.squeeze(y_pred, [-1]))
        # sum the per-example losses and scale by the global batch size
        loss = tf.reduce_sum(loss) * (1. / self.BATCH_SIZE)
        return loss

    def get_config(self):
        config = super().get_config().copy()
        config.update({'BATCH_SIZE': self.BATCH_SIZE})
        return config
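For comparison, the TensorFlow distributed-training guide suggests scaling per-example losses with tf.nn.compute_average_loss instead of dividing by the batch size by hand. A minimal sketch of how my call could be written that way (I have not actually switched to this yet):

    def call(self, y_true, y_pred):
        # per-example loss, shape = (B,)
        per_example_loss = mae(tf.squeeze(y_true, [-1]), tf.squeeze(y_pred, [-1]))
        # divides the per-replica sum by the GLOBAL batch size
        return tf.nn.compute_average_loss(per_example_loss,
                                          global_batch_size=self.BATCH_SIZE)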
With MirroredStrategy I train the model like this:
import os

def get_compiled_model(args, BATCH_SIZE):
    # Make a simple 2-layer densely-connected neural network.
    model = MyCustomModel(input_shape=(args.L, 1))
    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=args.learning_rate,
            beta_1=args.beta_1,
            beta_2=args.beta_2,
            epsilon=args.epsilon),
        loss=custom_loss(BATCH_SIZE))
    return model

def run_training(args, steps, model=None):
    # Create a MirroredStrategy.
    strategy = tf.distribute.MirroredStrategy()
    BATCH_SIZE_PER_REPLICA = args.batch_size
    BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

    # Open a strategy scope and create/restore the model.
    with strategy.scope():
        if model is None:
            model = get_compiled_model(args, BATCH_SIZE)

    train_dataset, test_dataset, valid_dataset = get_dataset(args, BATCH_SIZE)

    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(args.checkpoints_dir, steps + "_epoch-{epoch:03d}_loss-{loss:.4f}"),
            save_best_only=True)
    ]

    model.fit(train_dataset,
              epochs=args.epochs,
              callbacks=callbacks,
              validation_data=valid_dataset,
              steps_per_epoch=(None if args.steps_per_epoch == -1 else args.steps_per_epoch),
              validation_steps=(None if args.steps_per_epoch == -1 else args.steps_per_epoch),
              verbose=1)
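get_dataset is my own data-pipeline helper; the relevant part for this question is that it batches the tf.data datasets with the global BATCH_SIZE. A stripped-down sketch (shuffling/preprocessing omitted, load_arrays is just a placeholder for how the (L, 1)-shaped samples are loaded):

def get_dataset(args, BATCH_SIZE):
    # load_arrays is a placeholder, not the real loading code
    (x_train, y_train), (x_test, y_test), (x_valid, y_valid) = load_arrays(args)
    # batch with the GLOBAL batch size; MirroredStrategy splits each batch across replicas
    train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(BATCH_SIZE)
    test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
    valid = tf.data.Dataset.from_tensor_slices((x_valid, y_valid)).batch(BATCH_SIZE)
    return train, test, valid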
But if I run this on 4 GPUs, the loss value becomes 1/4 of the loss I get when running on a single GPU. Does it fail to sum up the losses from the different GPUs?
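To make the numbers concrete, here is my rough arithmetic of what each replica computes, assuming each replica's call() only receives its own per-replica slice of the batch (the batch size of 8 is just an example value):

PER_REPLICA_BATCH = 8
NUM_REPLICAS = 4
per_example_mae = [1.0] * PER_REPLICA_BATCH  # pretend every example has MAE = 1.0
single_gpu_loss = sum(per_example_mae) / (PER_REPLICA_BATCH * 1)             # 1.0
per_replica_loss = sum(per_example_mae) / (PER_REPLICA_BATCH * NUM_REPLICAS)  # 0.25, i.e. the 1/4 value I observe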