I'm writing CTGAN code and want it to train in a distributed way, so I'm using tf.distribute.MirroredStrategy(). The TensorFlow docs tutorial I'm following says you should call your train_step code from a function called distributed_train_step() and decorate that with @tf.function, like so:

@tf.function
def distributed_train_step(dataset_inputs):
  # `strategy` is the tf.distribute.MirroredStrategy instance;
  # strategy.run executes train_step once on each replica.
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)
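
For context, here is a stripped-down sketch of the kind of train_step I mean (a toy Keras model standing in for the actual CTGAN generator/discriminator; the .numpy() call marks where things break):

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.Adam()

def train_step(inputs):
    features, labels = inputs
    # NumPy logic like this works eagerly, but fails once train_step is
    # traced by tf.function: `features` is then a symbolic tensor and
    # has no .numpy() method.
    scale = np.std(features.numpy(), axis=0) + 1e-6
    with tf.GradientTape() as tape:
        preds = model(features / scale, training=True)
        loss = tf.reduce_mean(tf.square(labels - preds))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss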

This part is straightforward, but because tf.function traces everything reachable from train_step into a graph, all the NumPy code inside train_step stops working (the tensors are symbolic, so calls like .numpy() fail). What should I do? Is there an alternative where only selected functions inside train_step are wrapped? Or will I have to replace all the NumPy operations with TensorFlow's?
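
For what it's worth, one workaround I've been experimenting with is wrapping just the NumPy parts in tf.py_function, which runs them eagerly from inside the traced graph. A simplified sketch (the median computation is a placeholder for my actual NumPy logic, and I'm not sure how well py_function behaves across MirroredStrategy replicas):

import numpy as np
import tensorflow as tf

def numpy_part(x):
    # Arbitrary NumPy-only logic; tf.py_function runs this eagerly,
    # so x arrives here as an EagerTensor with .numpy() available.
    return np.median(x.numpy(), axis=0).astype(np.float32)

@tf.function
def train_step(features):
    # Feed the NumPy result back into the graph; static shape
    # information is lost and has to be restored with set_shape.
    med = tf.py_function(numpy_part, inp=[features], Tout=tf.float32)
    med.set_shape(features.shape[1:])
    return features - med

The other route, rewriting everything with TF equivalents (e.g. np.random.normal → tf.random.normal, np.concatenate → tf.concat), would mean touching a lot of code, hence the question.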
