How to initialize an LSTMCell with a tuple state


I recently upgraded TensorFlow from Rev8 to Rev12. In Rev8 the default "state_is_tuple" flag in rnn_cell.LSTMCell is False, so I initialized my LSTM cell with a single concatenated state tensor. See the code below.

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import rnn_cell

# Model definition (Rev8: state_is_tuple defaults to False)
lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim)
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state)


# init_state placeholder and feed_dict
def add_placeholders(self):
    # One tensor holding c and h concatenated, so cell_size should be 2 * hidden_dim
    self.init_state = tf.placeholder("float", [None, self.cell_size])

def get_feed_dict(self, data, label):
    feed_dict = {self.input_data: data,
                 self.input_label: label,
                 self.init_state: np.zeros((self.config.batch_size, self.cell_size))}
    return feed_dict
```

In Rev12 the default "state_is_tuple" flag is True, so to make my old code work I had to set the flag to False explicitly. However, now I get a warning from TensorFlow saying:

"Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True"
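
For reference, here is a minimal NumPy sketch (not TensorFlow code) of what a concatenated state looks like. It is based on my understanding of LSTMCell with state_is_tuple=False and no projection: the memory cell c and the hidden output h sit side by side along axis 1, so the state is 2 * hidden_dim wide. The sizes are hypothetical stand-ins for the question's config values.

```python
import numpy as np

# Hypothetical sizes standing in for self.config.batch_size / hidden_dim.
batch_size = 4
hidden_dim = 8

# Concatenated layout: [c | h] along axis 1, so the width is 2 * hidden_dim.
cell_size = 2 * hidden_dim
init_state = np.zeros((batch_size, cell_size), dtype=np.float32)

# Slicing the single array recovers the two components.
c, h = np.split(init_state, 2, axis=1)
print(c.shape, h.shape)  # (4, 8) (4, 8)
```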

I tried to initialize the LSTM cell with a tuple by changing the placeholder definition for self.init_state to the following:

```python
self.init_state = tf.placeholder("float", (None, self.cell_size))
```

but now I get an error message saying:

"'Tensor' object is not iterable"

Does anyone know how to make this work?
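
Why the attempt fails: with state_is_tuple=True the cell unpacks its state as a pair, roughly `c, h = state`. Writing the placeholder shape as (None, self.cell_size) instead of [None, self.cell_size] only changes how the shape is spelled. It still creates one Tensor, and a single Tensor cannot be unpacked, which is where "'Tensor' object is not iterable" comes from. A tuple state is instead two separate tensors, each hidden_dim wide.

The NumPy sketch below illustrates the structure. `StateTuple` is a hypothetical namedtuple that only mimics the field names (c, h) of TensorFlow's LSTMStateTuple.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for LSTMStateTuple: same field names, c and h.
StateTuple = namedtuple("StateTuple", ("c", "h"))

batch_size, hidden_dim = 4, 8

# One array, like one placeholder: the concatenated [c | h] state.
concat_state = np.zeros((batch_size, 2 * hidden_dim), dtype=np.float32)

# The tuple form: two separate arrays, each hidden_dim wide.
tuple_state = StateTuple(c=concat_state[:, :hidden_dim],
                         h=concat_state[:, hidden_dim:])

# Unpacking into (c, h) is what a tuple-state cell does with its input.
c, h = tuple_state

# Going back to the concatenated layout.
restored = np.concatenate([tuple_state.c, tuple_state.h], axis=1)
```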

Accepted answer by martianwars:

Feeding a "zero state" to an LSTM is much simpler now with cell.zero_state. You do not need to define the initial state explicitly as a placeholder. Define it as a tensor instead, and feed it only if required. This is how it works:

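```python
lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim)  # state_is_tuple=True by default in Rev12
# With a tuple state, zero_state returns an LSTMStateTuple of two zero tensors (c, h)
self.initial_state = lstm_cell.zero_state(self.config.batch_size, dtype=tf.float32)
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.initial_state)
```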
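
As a rough illustration, here is what zero_state produces for a tuple-state LSTM: one zero matrix of shape (batch_size, num_units) for c and one for h. This is a NumPy sketch, not TensorFlow's implementation. The `zero_state` function and the `StateTuple` namedtuple are hypothetical stand-ins.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for LSTMStateTuple (fields c and h).
StateTuple = namedtuple("StateTuple", ("c", "h"))


def zero_state(batch_size, num_units, dtype=np.float32):
    """Sketch of a tuple-valued zero state: zeros for c and for h."""
    return StateTuple(
        c=np.zeros((batch_size, num_units), dtype=dtype),
        h=np.zeros((batch_size, num_units), dtype=dtype),
    )


state = zero_state(batch_size=4, num_units=8)
print(state.c.shape, state.h.shape)  # (4, 8) (4, 8)
```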

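If you want to feed some other value as the initial state, for instance the final state of the previous batch, evaluate it in your session (e.g. next_state = session.run(states, feed_dict=feed_dict), where states is the final state returned by tf.nn.rnn) and pass it in the next feed_dict. With a tuple state, feed the c and h parts separately, as the PTB tutorial does:

```python
feed_dict[self.initial_state.c] = next_state.c
feed_dict[self.initial_state.h] = next_state.h
```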

In the context of your question, lstm_cell.zero_state() should suffice.
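
To show the idea of carrying the final state of one batch into the next, here is a self-contained NumPy sketch. It uses a toy LSTM, so it only approximates what TensorFlow computes. The weights, sizes and helper names (`lstm_step`, `run_batch`, `StateTuple`) are all hypothetical, and the gate layout is a common convention rather than a claim about TensorFlow's internals.

```python
from collections import namedtuple

import numpy as np

StateTuple = namedtuple("StateTuple", ("c", "h"))  # hypothetical LSTMStateTuple stand-in

rng = np.random.default_rng(0)
batch_size, input_dim, num_units, num_steps = 4, 3, 8, 5

# Hypothetical weights for the four gates, stacked into one matrix.
W = rng.normal(scale=0.1, size=(input_dim + num_units, 4 * num_units))
b = np.zeros(4 * num_units)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, state):
    """One basic LSTM step; returns (output, new StateTuple)."""
    z = np.concatenate([x, state.h], axis=1) @ W + b
    i, j, f, o = np.split(z, 4, axis=1)  # input, candidate, forget, output
    c = sigmoid(f) * state.c + sigmoid(i) * np.tanh(j)
    h = sigmoid(o) * np.tanh(c)
    return h, StateTuple(c, h)


def run_batch(inputs, initial_state):
    """Unroll over the time steps; returns per-step outputs and the final state."""
    state = initial_state
    outputs = []
    for x in inputs:
        out, state = lstm_step(x, state)
        outputs.append(out)
    return outputs, state


# Start from zeros, then feed each batch's final state into the next batch.
state = StateTuple(np.zeros((batch_size, num_units)), np.zeros((batch_size, num_units)))
for _ in range(3):
    inputs = [rng.normal(size=(batch_size, input_dim)) for _ in range(num_steps)]
    outputs, state = run_batch(inputs, state)
```

Feeding the final state back in (instead of resetting to zeros every batch) is what lets the model keep context across consecutive batches of one long sequence.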


Unrelated, but remember that you can feed values for any Tensor in the feed dictionary, not only for placeholders! That is how self.initial_state works in the example above. Have a look at the PTB tutorial for a working example.