I am trying to clip all the trainable variables of the discriminators in my network.
I get the variables for the discriminators like this:
A_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_d_')
B_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_B_')
discriminatorVars = A_d_vars + B_d_vars
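Here is a minimal sketch of how the scope argument filters the collection. It assumes TF 1.15 or 2.x accessed through tensorflow.compat.v1, and the variables (named w under the scopes A_d_, B_d_ and G_) are hypothetical. The scope string is matched against the start of each variable's name, and adding the two results only produces a plain Python list of variables:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph mode, matching the question's code

graph = tf.Graph()
with graph.as_default():
    # Hypothetical variables; only their name prefixes matter here.
    with tf.variable_scope("A_d_"):
        tf.get_variable("w", shape=[3], initializer=tf.ones_initializer())
    with tf.variable_scope("B_d_"):
        tf.get_variable("w", shape=[2, 2], initializer=tf.ones_initializer())
    with tf.variable_scope("G_"):  # e.g. a generator variable that must not be clipped
        tf.get_variable("w", shape=[4], initializer=tf.ones_initializer())

    # The scope argument is matched against the beginning of each variable name,
    # so the "G_" variable is excluded.
    A_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "A_d_")
    B_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "B_d_")

    # Adding two Python lists gives another Python list of variables,
    # not a tensor, so it has no .assign() method.
    discriminatorVars = A_d_vars + B_d_vars

print([v.name for v in discriminatorVars])  # ['A_d_/w:0', 'B_d_/w:0']
```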
Now, if I try to clip all the values to [0.01, 0.1] like this:
discriminatorVars.assign(tf.clip_by_value(discriminatorVars, 0.01, 0.1))
it doesn't work, because discriminatorVars is a Python list, not a tensor or a variable.
I also tried this, but it doesn't work:
self.sess.run(tf.map_fn(lambda var: var.assign(tf.clip_by_value(var, 0.01, 0.1)), discriminatorVars))
It says that the list object has no assign method.
Currently I loop through all the variables in the list and call
self.sess.run(var.assign(tf.clip_by_value(var, 0.01, 0.1)))
for each one. The problem is that this is very slow.
How can I batch-update the collections so their values will be clipped?
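The per-variable loop is slow for two reasons, which this minimal sketch reproduces. It assumes TF 1.15 or 2.x via tensorflow.compat.v1, and the variables d_w1 and d_w2 are hypothetical. First, every Session.run() call has a fixed overhead. Second, in graph mode each call to tf.clip_by_value and var.assign inside the loop adds brand-new ops to the graph, so the graph keeps growing with every training step:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph mode, matching self.sess in the question

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-ins for the discriminator variables.
    var_list = [
        tf.get_variable("d_w1", shape=[3], initializer=tf.ones_initializer()),
        tf.get_variable("d_w2", shape=[2, 2], initializer=tf.ones_initializer()),
    ]
    init = tf.global_variables_initializer()
    ops_before = len(graph.get_operations())

    with tf.Session() as sess:
        sess.run(init)
        for step in range(5):  # stand-in for the training loop
            for var in var_list:
                # Slow pattern: new clip/assign ops are created on every iteration,
                # and each variable pays for its own Session.run() call.
                sess.run(var.assign(tf.clip_by_value(var, 0.01, 0.1)))

    ops_after = len(graph.get_operations())

print("ops added by the loop:", ops_after - ops_before)
```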
Try making a list of the assign ops you want to run, and use tf.group (https://www.tensorflow.org/api_docs/python/tf/group) to group them. Pass the op returned by tf.group to sess.run. Session.run() can have a non-trivial overhead, so you want to do all the updates in a single Session.run() call.
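The approach above can be sketched as follows. This assumes TF 1.15 or 2.x via tensorflow.compat.v1. The variables A_d_w and B_d_w and the op name clip_discriminator are hypothetical stand-ins for the real discriminator variables. A plain list comprehension builds one assign op per variable; tf.map_fn is not needed. tf.group then combines those ops into a single op, which is built once, outside the training loop:

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph mode, matching self.sess in the question

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-ins for A_d_vars + B_d_vars; the shapes differ on purpose.
    discriminatorVars = [
        tf.get_variable("A_d_w", initializer=tf.constant([-1.0, 0.05, 2.0])),
        tf.get_variable("B_d_w", initializer=tf.constant([[0.0, 0.5], [0.02, -3.0]])),
    ]

    # Build the clip ops once: one assign op per variable, grouped into a single op.
    clip_discriminator = tf.group(
        *[var.assign(tf.clip_by_value(var, 0.01, 0.1)) for var in discriminatorVars]
    )
    init = tf.global_variables_initializer()
    graph.finalize()  # adding any op after this point raises, so the loop cannot grow the graph

    with tf.Session() as sess:
        sess.run(init)
        for step in range(3):  # stand-in for the training loop
            # (the discriminator's optimizer step would run here)
            sess.run(clip_discriminator)  # one Session.run() clips every variable
        clipped = sess.run(discriminatorVars)

print(clipped[0])  # every element now lies in [0.01, 0.1]
print(clipped[1])
```

In a real model, clip_discriminator would be created once next to the optimizer ops and run after each discriminator update. Because graph.finalize() succeeds and the loop still runs, no new ops are created per step, and the cost per step is one Session.run() call instead of one per variable.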