I'm writing CTGAN code and want to train it in a distributed way, so I'm using tf.distribute.MirroredStrategy(). In the TensorFlow docs tutorial I'm following, it says you should call your train_step code from a function called distributed_train_step() and decorate that with @tf.function, like so:
@tf.function
def distributed_train_step(dataset_inputs):
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                           axis=None)
This is straightforward, but because the tf.function decorator traces everything reachable from distributed_train_step into a graph, the numpy code inside train_step no longer executes eagerly and breaks. What should I do? Is there an alternative where I only wrap parts of train_step selectively? Or will I have to replace all numpy operations with TensorFlow ops?
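For context, here is the kind of selective wrapping I have in mind: a minimal sketch (not my real code) where a hypothetical numpy helper sample_condvec stands in for CTGAN's conditional-vector sampling, and only that call is escaped out of the graph via tf.py_function:

```python
import numpy as np
import tensorflow as tf

# Hypothetical numpy helper standing in for CTGAN's
# conditional-vector sampling, which I want to keep as numpy.
def sample_condvec(batch_size):
    return np.random.randint(0, 2, size=(int(batch_size), 4)).astype(np.float32)

@tf.function
def train_step(batch_size):
    # Wrap only the numpy call; everything else stays as graph ops.
    cond = tf.py_function(sample_condvec, inp=[batch_size], Tout=tf.float32)
    cond.set_shape([None, 4])  # py_function loses static shape info
    return tf.reduce_sum(cond)

out = train_step(tf.constant(8))
```

My worry is that tf.py_function executes on the host Python interpreter, so I'm not sure how well this plays with strategy.run() replicating train_step across GPUs, or whether the eager round-trip would hurt throughput.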