I'd like to use the same computation graph with different inputs in TensorFlow. Unfortunately, I didn't write the model function and I would like to avoid modifying it if at all possible.
def model_fn(features, labels, mode):
    # some code here
    return loss, train_op
I'm writing a federated learning algorithm and want to train the same model on n different datasets without running n separate clients, so I planned to use a single computation graph fed with different inputs.
I planned on doing this like so:
with tf.variable_scope("client", reuse=tf.AUTO_REUSE):
    for _ in range(num_of_clients):
        features, labels = input_fn()
        client_model = model_fn(features, labels, "train")
I had hoped that every client_model would reuse the same set of shared variables.
- I can confirm that it doesn't by inspecting tf.global_variables().
- I also know that this would work if model_fn() used tf.get_variable(), but it doesn't.
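The behaviour described in the list can be reproduced in a few lines. This is a minimal sketch assuming TensorFlow 1.x graph semantics (run here through tf.compat.v1); with_variable and with_get_variable are hypothetical helpers standing in for the two ways a model_fn might create a weight, and count_global_variables is an illustrative name.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs and sessions


def count_global_variables(create_weight, num_clients=3):
    """Call create_weight() num_clients times under AUTO_REUSE; count variables."""
    with tf.Graph().as_default():
        with tf.variable_scope("client", reuse=tf.AUTO_REUSE):
            for _ in range(num_clients):
                create_weight()
        return len(tf.global_variables())


def with_variable():
    # tf.Variable ignores the enclosing scope's reuse flag and
    # uniquifies the name instead (client/w, client/w_1, client/w_2).
    return tf.Variable(tf.zeros([2]), name="w")


def with_get_variable():
    # tf.get_variable honours reuse=tf.AUTO_REUSE and returns client/w again.
    return tf.get_variable("w", shape=[2], initializer=tf.zeros_initializer())


print(count_global_variables(with_variable))      # 3
print(count_global_variables(with_get_variable))  # 1
```

This is why wrapping the loop in a variable_scope has no effect when model_fn builds its weights with tf.Variable: the reuse flag is only consulted by tf.get_variable.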
How can I share the variables across these calls in TensorFlow without modifying model_fn (otherwise I would just use this answer)?
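One possible workaround, sketched under assumptions and not tested against the real model_fn: tf.compat.v1.variable_creator_scope lets you intercept every tf.Variable(...) call, so a creator function can hand back the variables made during the first model_fn call instead of creating new ones. The model_fn below is a hypothetical stand-in (a linear model built with tf.Variable), and shared_variables, build_clients and demo are illustrative names. Variables are matched by creation order, which assumes model_fn creates the same variables in the same order on every call.

```python
import contextlib

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def model_fn(features, labels, mode):
    # Hypothetical stand-in for the real model_fn: like the original,
    # it creates its weights with tf.Variable, not tf.get_variable.
    w = tf.Variable(tf.zeros([3, 1]), name="w")
    b = tf.Variable(tf.zeros([1]), name="b")
    predictions = tf.matmul(features, w) + b
    loss = tf.reduce_mean(tf.square(predictions - labels))
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    return loss, train_op


@contextlib.contextmanager
def shared_variables(cache):
    """The i-th variable created in this block reuses cache[i] if it exists."""
    counter = [0]  # restarts at 0 on every entry, i.e. every model_fn call

    def creator(next_creator, **kwargs):
        i = counter[0]
        counter[0] += 1
        if i < len(cache):
            return cache[i]  # later calls: reuse the first call's variable
        var = next_creator(**kwargs)  # first call: really create it
        cache.append(var)
        return var

    with tf.variable_creator_scope(creator):
        yield


def build_clients(client_data):
    """Call model_fn once per client; all calls share one set of variables."""
    cache = []
    client_models = []
    for x, y in client_data:
        with shared_variables(cache):
            client_models.append(
                model_fn(tf.constant(x), tf.constant(y), "train"))
    return client_models


def demo(num_clients=3, steps=5):
    rng = np.random.RandomState(0)
    data = [(rng.rand(8, 3).astype(np.float32),
             rng.rand(8, 1).astype(np.float32)) for _ in range(num_clients)]
    with tf.Graph().as_default():
        models = build_clients(data)
        num_vars = len(tf.global_variables())
        losses = []
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for loss, train_op in models:
                for _ in range(steps):
                    sess.run(train_op)
                losses.append(float(sess.run(loss)))
    return num_vars, losses


num_vars, losses = demo()
print(num_vars)  # 2: one w and one b shared by all clients
```

Because every client's train_op updates the same w and b, this trains the clients one after another on shared weights; a federated-averaging loop would still need per-client copies of the weights (or a snapshot and restore between clients), which this sketch does not attempt.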