How to reuse a computation graph with different inputs?


I'd like to use the same computation graph with different inputs in TensorFlow. Unfortunately, I didn't write the model function and I would like to avoid modifying it if at all possible.

def model_fn(features, labels, mode):
    # some code here
    return loss, train_op

I'm writing a federated learning algorithm and want to train the same model on n different datasets without spinning up n separate clients, so I planned to use a single computation graph fed with different inputs.
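To make the intended training pattern concrete, here is a minimal, framework-free sketch of the federated-averaging loop described above. Everything in it (the one-parameter linear model, the synthetic client datasets, the hyperparameters) is made up for illustration; it is the pattern, not the TensorFlow implementation:

```python
# Minimal sketch of federated averaging: one shared model, n client
# datasets, per-round local training followed by parameter averaging.
# Model and data are illustrative only (ground truth is y = 2x).

def grad(w, x, y):
    # gradient of the squared error 0.5 * (w*x - y)**2 w.r.t. w
    return (w * x - y) * x

def local_train(w, dataset, lr=0.1, steps=5):
    # each "client" takes a few SGD steps starting from the shared model
    for _ in range(steps):
        for x, y in dataset:
            w -= lr * grad(w, x, y)
    return w

# Three clients, each holding one sample consistent with y = 2x.
clients = [[(1.0, 2.0)], [(2.0, 4.0)], [(3.0, 6.0)]]

w = 0.0  # shared model parameter
for _ in range(20):
    local = [local_train(w, ds) for ds in clients]  # same model, n datasets
    w = sum(local) / len(local)                     # federated averaging

print(round(w, 2))  # → 2.0
```

The point is that every round trains the *same* model on each dataset, which is exactly why the computation graph (and its variables) must be shared rather than rebuilt per client.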

I planned on doing this like so:

with tf.variable_scope("client", reuse=tf.AUTO_REUSE):
    for _ in range(num_of_clients):
        features, labels = input_fn()
        client_model = model_fn(features, labels, "train")

I had hoped that each client_model would reuse the same set of shared variables.

  • I can confirm with tf.global_variables() that it doesn't: each iteration creates a fresh copy of every variable.
  • I also know that this would work if model_fn() used tf.get_variable(), but it doesn't use it.
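For illustration, here is a toy, non-TensorFlow sketch of the distinction in the bullets above: tf.get_variable-style lookup goes through a per-graph store keyed by scoped name (which is what reuse=tf.AUTO_REUSE hooks into), while direct construction makes a new object every call. This is a simplified model of the behavior, not the real TF implementation:

```python
# Toy model of why AUTO_REUSE only helps code that uses tf.get_variable:
# lookup-by-scoped-name can return an existing variable, while
# unconditional construction (tf.Variable-style) never consults the store.

variable_store = {}  # maps "scope/name" -> variable object

def get_variable(scope, name):
    """get_variable-style: reuse an existing variable with this scoped
    name, or create it on first use (the AUTO_REUSE behavior)."""
    key = f"{scope}/{name}"
    if key not in variable_store:
        variable_store[key] = object()
    return variable_store[key]  # later calls return the same object

def make_variable(scope, name):
    """tf.Variable-style: constructs unconditionally, ignores the store."""
    return object()

a = get_variable("client", "w")
b = get_variable("client", "w")
c = make_variable("client", "w")
d = make_variable("client", "w")

assert a is b       # shared: same underlying variable
assert c is not d   # not shared: a fresh variable on every call
```

Because model_fn creates its variables the second way, wrapping it in a reusing variable_scope (or in tf.make_template) has no effect on variable sharing.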

How would I do this in TensorFlow without modifying the model_fn (otherwise I would just use this answer)?


There are 0 answers