Distributed training with TensorFlow on x GPUs makes the loss 1/x


I was trying to train a model on multiple GPUs with TensorFlow's MirroredStrategy.

I used a custom loss function like this:

import tensorflow as tf

def mae(y_true, y_pred):
    # y_true, y_pred shape = (B, L)
    loss = tf.keras.metrics.mean_absolute_error(y_true, y_pred)
    # loss shape = (B,)
    return loss

class custom_loss(tf.keras.losses.Loss):
    def __init__(self, BATCH_SIZE=1, **kwargs):
        super(custom_loss, self).__init__(**kwargs)
        self.BATCH_SIZE = BATCH_SIZE  # global batch size, not per-replica

    def call(self, y_true, y_pred):
        # y_true, y_pred shape = (B, L, 1)
        loss = mae(tf.squeeze(y_true, [-1]), tf.squeeze(y_pred, [-1]))
        # sum the per-example losses and divide by the global batch size
        loss = tf.reduce_sum(loss) * (1. / self.BATCH_SIZE)
        return loss

    def get_config(self):
        config = super().get_config().copy()
        config.update({'BATCH_SIZE': self.BATCH_SIZE})
        return config
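
To sanity-check the scaling on its own, the loss can be called on a toy batch (the shapes and values below are made up for illustration, not taken from my data): the per-example MAEs are summed over the batch and divided by the global batch size, so a replica that only sees a fraction of the global batch returns a correspondingly smaller value.

# Toy sanity check (shapes/values are assumptions, not from the real model)
y_true = tf.ones((2, 8, 1))   # (B, L, 1) with B = 2 examples on this replica
y_pred = tf.zeros((2, 8, 1))  # per-example MAE is therefore 1.0

loss_fn = custom_loss(BATCH_SIZE=8)      # pretend the *global* batch is 8 (4 replicas x 2)
print(loss_fn(y_true, y_pred).numpy())   # sum(1.0, 1.0) / 8 = 0.25, i.e. 1/4 of the per-example mean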

With MirroredStrategy I train the model like this:

import os

def get_compiled_model(args, BATCH_SIZE):
    # Make a simple 2-layer densely-connected neural network.
    model = MyCustomModel(input_shape=(args.L, 1))
    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=args.learning_rate,
            beta_1=args.beta_1,
            beta_2=args.beta_2,
            epsilon=args.epsilon,
        ),
        loss=custom_loss(BATCH_SIZE),
    )
    return model

def run_training(args, steps, model=None):
    # Create a MirroredStrategy.
    strategy = tf.distribute.MirroredStrategy()
    BATCH_SIZE_PER_REPLICA = args.batch_size
    BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

    # Open a strategy scope and create/restore the model.
    with strategy.scope():
        if model is None:
            model = get_compiled_model(args, BATCH_SIZE)
        train_dataset, test_dataset, valid_dataset = get_dataset(args, BATCH_SIZE)
        callbacks = [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=os.path.join(args.checkpoints_dir, steps + "_epoch-{epoch:03d}_loss-{loss:.4f}"),
                save_best_only=True,
            )
        ]
        model.fit(
            train_dataset,
            epochs=args.epochs,
            callbacks=callbacks,
            validation_data=valid_dataset,
            steps_per_epoch=(None if args.steps_per_epoch == -1 else args.steps_per_epoch),
            validation_steps=(None if args.steps_per_epoch == -1 else args.steps_per_epoch),
            verbose=1,
        )

But if I run this on 4 GPUs, the reported loss value becomes 1/4 of the loss I get when running on a single GPU. Does it fail to sum up the per-replica losses from the different GPUs?
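
To make the factor concrete, here is a rough back-of-the-envelope sketch with made-up numbers (the per-example loss of 1.0 and batch sizes are just for illustration): on a single GPU the loss is divided by its own batch size, while on 4 GPUs each replica divides its sum by the global batch size, so the displayed value depends on whether the per-replica values are summed or averaged.

# Back-of-the-envelope arithmetic with made-up numbers (not from the actual run)
per_example_loss = 1.0               # assume every example has the same MAE
B_per_replica    = 8
num_replicas     = 4
global_batch     = B_per_replica * num_replicas

# single GPU: BATCH_SIZE == B_per_replica, so the loss is the per-example mean
single_gpu_loss  = per_example_loss * B_per_replica / B_per_replica    # 1.0

# 4 GPUs: each replica sums over its 8 examples but divides by the global batch of 32
per_replica_loss = per_example_loss * B_per_replica / global_batch     # 0.25

# summing the 4 replica values would recover 1.0; averaging them leaves 0.25,
# which matches the factor of 1/num_replicas I am seeing
print(single_gpu_loss, per_replica_loss, per_replica_loss * num_replicas)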
