RNN Cell not present in tf.get_collection


RNN Cells are not showing up when using tf.get_collection(). What am I missing?

import tensorflow as tf
print(tf.__version__)

rnn_cell = tf.nn.rnn_cell.LSTMCell(16)
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

other_var = tf.Variable(0)
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

This prints:

0.12.0
[]
[<tensorflow.python.ops.variables.Variable object at 0x0000027961250B70>]

Windows 10, Python 3.5


There is 1 answer

Answer by martianwars (best answer)

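You have not yet called the LSTMCell on an input (that is, invoked its __call__), which is why you don't see its variables. Constructing the cell only stores its configuration; the variables are created the first time the cell is called. Try this instead (I'm assuming batch_size=10 and rnn_size=16):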

import tensorflow as tf
print(tf.__version__)

rnn_cell = tf.nn.rnn_cell.LSTMCell(16)
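# Input batch: 10 examples (batch_size), 16 features each
a = tf.placeholder(tf.float32, [10, 16])
# Initial state for a batch of 10
zero = rnn_cell.zero_state(10, tf.float32)
# The variables are created in the following __call__,
# which returns the cell output and the new state
b = rnn_cell(a, zero)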
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

other_var = tf.Variable(0)
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
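
To see why the first listing prints an empty collection, here is a toy, pure-Python sketch of the same lazy-creation pattern. It does not use TensorFlow and is not how TensorFlow is implemented: `LazyCell` and the `VARIABLES` list are hypothetical stand-ins for the cell and for the GLOBAL_VARIABLES collection. It assumes, as one plausible reason for the laziness, that the input width is only known at call time.

```python
# Toy illustration only: LazyCell and VARIABLES are hypothetical names,
# not part of TensorFlow.
VARIABLES = []  # stands in for the graph's GLOBAL_VARIABLES collection


class LazyCell:
    """Creates its weights on the first call, not in the constructor."""

    def __init__(self, num_units):
        self.num_units = num_units
        self.weights = None  # nothing is registered yet

    def __call__(self, inputs):
        if self.weights is None:
            # The input width is only known here, so the weights are built now.
            input_size = len(inputs)
            self.weights = [[0.0] * self.num_units for _ in range(input_size)]
            VARIABLES.append(self.weights)
        # Plain vector-matrix product: one output value per unit.
        return [
            sum(x * row[j] for x, row in zip(inputs, self.weights))
            for j in range(self.num_units)
        ]


cell = LazyCell(16)
print(len(VARIABLES))  # 0: constructing the cell registers nothing

output = cell([1.0] * 16)
print(len(VARIABLES))  # 1: the first call created and registered the weights

cell([2.0] * 16)
print(len(VARIABLES))  # still 1: later calls reuse the same weights
```

The same ordering applies to the answer above: query the collection only after the cell has been called at least once.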