TensorFlow: how to clip the values of all variables in a collection?


I am trying to clip all the trainable variables of the discriminators in my network.

I get the variables for the discriminators like this:

A_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_d_')
B_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_B_')
discriminatorVars = A_d_vars + B_d_vars

Now, if I try discriminatorVars.assign(tf.clip_by_value(discriminatorVars, 0.01, 0.1)) to clip all the values to [0.01, 0.1], it doesn't work, because discriminatorVars is a Python list of variables, not a single tensor or variable.

I also tried this, but it doesn't work:

self.sess.run(tf.map_fn(lambda var: var.assign(tf.clip_by_value(var, 0.01, 0.1)), var_list))

It says that a list object has no assign method.

Currently I loop over all the variables in the list and call self.sess.run(var.assign(tf.clip_by_value(var, 0.01, 0.1))) for each one. The problem is that this is very slow.
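A minimal sketch of one likely reason the loop gets slow, assuming TF1-style graph mode (run here through tensorflow.compat.v1) and a single hypothetical variable standing in for the discriminator weights: every call to var.assign(tf.clip_by_value(var, 0.01, 0.1)) builds new ops, so if the loop runs on every training step, the graph keeps growing, on top of paying for one Session.run() per variable.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and Sessions, as in the question

graph = tf.Graph()
with graph.as_default():
    # Hypothetical discriminator variable under the question's 'A_d_' prefix.
    with tf.variable_scope("A_d_"):
        w = tf.get_variable("w", initializer=tf.constant([0.5, -0.2]))
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "A_d_")

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        op_counts = []
        for _ in range(3):  # stand-in for three training iterations
            # Pattern from the question: each pass builds new clip and assign ops,
            # then runs them with one Session.run() per variable.
            for var in d_vars:
                sess.run(var.assign(tf.clip_by_value(var, 0.01, 0.1)))
            op_counts.append(len(graph.get_operations()))
        w_value = sess.run(w)

print(op_counts)  # increases every iteration because new ops were added each time
print(w_value)
```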

How can I batch-update the variables in these collections so that their values are clipped?

Answer by Peter Hawkins

Try building a list of the assign ops you want to run, and use tf.group (https://www.tensorflow.org/api_docs/python/tf/group) to group them into a single op. Then pass that grouped op to sess.run.
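The suggestion above can be sketched as follows, again assuming tensorflow.compat.v1 and two hypothetical variables. The scope prefix 'B_d_' is an assumption (the question uses 'A_B_'). The assign ops are built once, bundled with tf.group, and the bundled op is run with a single sess.run() call.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and Sessions, as in the question

graph = tf.Graph()
with graph.as_default():
    # Hypothetical discriminator variables; 'B_d_' is an assumed scope prefix.
    with tf.variable_scope("A_d_"):
        a_w = tf.get_variable("w", initializer=tf.constant([0.5, -0.2, 0.05]))
    with tf.variable_scope("B_d_"):
        b_w = tf.get_variable("w", initializer=tf.constant([[2.0, 0.03]]))

    A_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "A_d_")
    B_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "B_d_")
    discriminatorVars = A_d_vars + B_d_vars  # a plain Python list of variables

    # Build one assign op per variable, once, at graph-construction time...
    clip_ops = [var.assign(tf.clip_by_value(var, 0.01, 0.1)) for var in discriminatorVars]
    # ...then bundle them into a single op that runs all of them.
    clip_discriminator = tf.group(*clip_ops)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # In a training loop, only this line would be repeated each step.
        sess.run(clip_discriminator)  # one Session.run() clips every variable
        a_clipped, b_clipped = sess.run([a_w, b_w])

print(a_clipped)
print(b_clipped)
```

Because the ops are created before the loop, the graph stays the same size no matter how many steps run.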

Each Session.run() call can carry non-trivial overhead, so you want to perform all the updates in a single Session.run() call.
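If the clipping should run after every discriminator update, one option (an extension of this advice, not something the answer states) is to make the grouped clip op depend on the training op with tf.control_dependencies, so that one Session.run() per step performs both the update and the clipping. The variable, loss, optimizer, and learning rate below are hypothetical placeholders, not the asker's model. The clip reads each variable through read_value() inside the control_dependencies block, so that read is ordered after the update.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and Sessions, as in the question

graph = tf.Graph()
with graph.as_default():
    # Hypothetical discriminator variable and placeholder loss.
    with tf.variable_scope("A_d_"):
        w = tf.get_variable("w", initializer=tf.constant([0.05, 0.05]))
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "A_d_")
    loss = w[0] - w[1]  # gradient descent pushes w[0] down and w[1] up

    train_step = tf.train.GradientDescentOptimizer(1.0).minimize(loss, var_list=d_vars)

    # Ops created inside this block only run after train_step has run.
    with tf.control_dependencies([train_step]):
        clip_ops = [v.assign(tf.clip_by_value(v.read_value(), 0.01, 0.1)) for v in d_vars]
        train_and_clip = tf.group(*clip_ops)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(3):
            sess.run(train_and_clip)  # update and clip in a single Session.run()
        final_w = sess.run(w)

print(final_w)
```

Without the clip, three steps would leave the weights far outside [0.01, 0.1]; with it, they end at the bounds after every step.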