How to freeze specific nodes in a tensorflow variable while training?


Currently I am having trouble making a few elements of a variable non-trainable. That is, given a variable such as x,

x = tf.Variable(tf.zeros([2, 2]))

I wish to train only x[0,0] and x[1,1] while keeping x[0,1] and x[1,0] fixed during training.

TensorFlow does provide the option to make an entire variable non-trainable, using trainable=False or tf.stop_gradient(). However, these methods make all elements of x non-trainable. My question is how to obtain this selectivity.


There are 2 answers

dominikroblek

You can use the tf.stop_gradient trick to prevent masked tf.Variable elements from being trained. For example:

x = tf.Variable(tf.zeros([2, 2]))
mask = tf.constant([[1, 0], [0, 1]], dtype=x.dtype.base_dtype)
# Gradients flow only through the first term; tf.stop_gradient passes the
# masked-out values forward but blocks their gradients.
x = mask * x + tf.stop_gradient((1 - mask) * x)
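
As a quick sanity check, computing the gradient shows that the masked-out entries receive exactly zero gradient. A minimal sketch, assuming TensorFlow 1.x graph mode; the toy loss and session boilerplate are my additions, not part of the original answer:

import tensorflow as tf

x = tf.Variable(tf.zeros([2, 2]))
mask = tf.constant([[1, 0], [0, 1]], dtype=tf.float32)
x_masked = mask * x + tf.stop_gradient((1 - mask) * x)

loss = tf.reduce_sum(tf.square(x_masked - 1.0))  # hypothetical toy loss
grad = tf.gradients(loss, x)[0]  # gradient w.r.t. the original variable

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grad))
    # [[-2.  0.]
    #  [ 0. -2.]]  -- the off-diagonal entries get zero gradient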
lejlot

There is no selective lack of update as of now; however, you can achieve this effect indirectly by explicitly specifying the variables that should be updated. Both .minimize and all the gradient functions accept the list of variables you want to optimize over. Just create a list that omits some of them, for example:

v1 = tf.Variable( ... )  # we want to freeze it in one op
v2 = tf.Variable( ... )  # we want to freeze it in another op
v3 = tf.Variable( ... )  # we always want to train this one
loss = ...
optimizer = tf.train.GradientDescentOptimizer(0.1)

op1 = optimizer.minimize(loss,
      var_list=[v for v in tf.trainable_variables() if v is not v1])

op2 = optimizer.minimize(loss,
      var_list=[v for v in tf.trainable_variables() if v is not v2])

and now you can call them whenever you want, to train with respect to a chosen subset of variables. Note that this might require two separate optimizers if you are using Adam or some other method that gathers statistics (and you will end up with separate statistics per optimizer!). However, if there is just one set of frozen variables per training run, everything is straightforward with var_list.
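
For instance, a minimal training loop could alternate the two ops; the alternating schedule below is just an illustration, my assumption rather than part of the original answer:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1000):
        # v1 is frozen on even steps, v2 on odd steps
        sess.run(op1 if step % 2 == 0 else op2)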

However, there is no way to freeze training of a subset of a single variable; TensorFlow always treats a variable as one unit. You have to structure your computation differently to achieve this. One way is to:

  • create a binary mask M with 1s where you want to stop updates to X
  • create a separate variable X', which is non-trainable, and tf.assign the value of X to it
  • output X'*M + (1-M)*X

for example:

x = tf.Variable( ... )
xp = tf.Variable( ..., trainable=False)
m = tf.constant( ... )  # mask: 1 where updates should be stopped
cp = tf.assign(xp, x)   # copy the current value of x into xp
with tf.control_dependencies([cp]):
  x_frozen = m*xp + (1-m)*x

and you just use x_frozen instead of x. Note that the control dependency is needed because tf.assign can execute asynchronously; here we want to make sure xp always holds the most up-to-date value of x.
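
To make the recipe concrete, here is a runnable sketch; the 2x2 shape, the mask values, and the toy loss are my assumptions, chosen to mirror the question:

import tensorflow as tf

x = tf.Variable(tf.zeros([2, 2]))
xp = tf.Variable(tf.zeros([2, 2]), trainable=False)
m = tf.constant([[0., 1.], [1., 0.]])  # 1 where updates are stopped

cp = tf.assign(xp, x)  # snapshot the current value of x into xp
with tf.control_dependencies([cp]):
    x_frozen = m * xp + (1 - m) * x

loss = tf.reduce_sum(tf.square(x_frozen - 1.0))  # hypothetical toy loss
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(100):
        sess.run(train_op)
    print(sess.run(x))  # entries where m == 1 remain at 0

Because loss depends on x only through the (1-m)*x term, the masked positions of x receive zero gradient and never move, while the unmasked diagonal entries train normally.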