LSTM timesteps in Sonnet

I'm currently trying to learn Sonnet.

My network (incomplete, the question is based on this):

import numpy as np
import sonnet as snt
import tensorflow as tf


class Model(snt.AbstractModule):

    def __init__(self, name="LSTMNetwork"):
        super(Model, self).__init__(name=name)
        with self._enter_variable_scope():
            self.l1 = snt.LSTM(100)
            self.l2 = snt.LSTM(100)
            self.out = snt.LSTM(10)

    def _build(self, inputs):

        # 'inputs' is of shape (batch_size, input_length)
        # I need it to be of shape (batch_size, sequence_length, input_length)

        l1_state = self.l1.initial_state(np.shape(inputs)[0])  # init with batch_size
        l2_state = self.l2.initial_state(np.shape(inputs)[0])  # init with batch_size
        out_state = self.out.initial_state(np.shape(inputs)[0])

        l1_out, l1_state = self.l1(inputs, l1_state)
        l1_out = tf.tanh(l1_out)
        l2_out, l2_state = self.l2(l1_out, l2_state)
        l2_out = tf.tanh(l2_out)
        output, out_state = self.out(l2_out, out_state)
        output = tf.sigmoid(output)

        return output, out_state

In other frameworks (e.g. Keras), LSTM inputs are of the form (batch_size, sequence_length, input_length).

However, the Sonnet documentation states that the input to Sonnet's LSTM is of the form (batch_size, input_length).

How do I use them for sequential input?

So far, I've tried using a for loop inside _build, iterating over each timestep, but that gives seemingly random outputs.
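
Roughly, the per-timestep loop I tried inside _build looked like this (simplified, names illustrative, assuming I first reshape 'inputs' to (batch_size, sequence_length, input_length)):

# Simplified sketch of the per-timestep loop I tried:
outputs = []
for t in range(sequence_length):
    x_t = inputs[:, t, :]  # (batch_size, input_length) slice for timestep t
    l1_out, l1_state = self.l1(x_t, l1_state)
    l1_out = tf.tanh(l1_out)
    l2_out, l2_state = self.l2(l1_out, l2_state)
    l2_out = tf.tanh(l2_out)
    out, out_state = self.out(l2_out, out_state)
    outputs.append(tf.sigmoid(out))
output = tf.stack(outputs, axis=1)  # (batch_size, sequence_length, 10)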

I've tried the same architecture in Keras, which runs without any issues.
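
The (roughly) equivalent Keras model I tested was along these lines (simplified, with the tanh/sigmoid folded into the layer activations):

from tensorflow import keras

keras_model = keras.Sequential([
    keras.layers.LSTM(100, activation="tanh", return_sequences=True,
                      input_shape=(sequence_length, input_length)),
    keras.layers.LSTM(100, activation="tanh", return_sequences=True),
    keras.layers.LSTM(10, activation="sigmoid"),
])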

I'm executing in eager mode, using GradientTape for training.
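
The training step is shaped roughly like this (loss and data details omitted, names illustrative):

optimizer = tf.train.AdamOptimizer(1e-3)

def train_step(model, inputs, targets):
    with tf.GradientTape() as tape:
        predictions, _ = model(inputs)
        loss = tf.reduce_mean(tf.square(targets - predictions))
    variables = tape.watched_variables()
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss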

1 Answer

Answered by Malcolm Reynolds:

We generally wrote the RNNs in Sonnet to work on a single-timestep basis, because for Reinforcement Learning you often need to run one timestep to pick an action, and without that action you can't get the next observation (and the next input timestep) from the environment. It's easy to unroll a single-timestep module over a sequence using tf.nn.dynamic_rnn (see below). We also have a wrapper which takes care of composing several RNN cores per timestep, which I believe is what you're looking to do. This has the advantage that the DeepRNN object supports the start-state methods required for dynamic_rnn, so it's API compatible with LSTM or any other single-timestep module.

What you want to do should be achievable like this:

import sonnet as snt
import tensorflow as tf

# Create a single-timestep RNN module by composing recurrent modules and
# non-recurrent ops.
model = snt.DeepRNN([
    snt.LSTM(100),
    tf.tanh,
    snt.LSTM(100),
    tf.tanh,
    snt.LSTM(10),
    tf.sigmoid
], skip_connections=False)

batch_size = 2
sequence_length = 3
input_size = 4

single_timestep_input = tf.random_uniform([batch_size, input_size])
sequence_input = tf.random_uniform([batch_size, sequence_length, input_size])

# Run the module on a single timestep
single_timestep_output, next_state = model(
    single_timestep_input, model.initial_state(batch_size=batch_size))

# Unroll the module on a full sequence
sequence_output, final_state = tf.nn.dynamic_rnn(
    model, sequence_input, dtype=tf.float32)
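
If you want the unroll to start from a particular state (for example one carried over from a previous unroll), you can pass it explicitly instead of dtype, something like:

# Unroll from an explicit initial state rather than the default zero state.
initial_state = model.initial_state(batch_size=batch_size)
sequence_output, final_state = tf.nn.dynamic_rnn(
    model, sequence_input, initial_state=initial_state)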

A few things to note: if you haven't already, please have a look at the RNN example in the repository, as it shows a full graph-mode training setup around a fairly similar model.

Secondly, if you do end up needing to implement a more complex module than DeepRNN allows for, it's important to thread the recurrent state in and out of the module. In your example you're creating the initial state internally, and the l1_state and l2_state outputs are effectively discarded, so the model can't be trained properly. If DeepRNN wasn't available, your model would look like this:

class LSTMNetwork(snt.RNNCore):  # Note we inherit from the RNN-specific subclass
  def __init__(self, name="LSTMNetwork"):
    super(LSTMNetwork, self).__init__(name=name)
    with self._enter_variable_scope():
      self.l1 = snt.LSTM(100)
      self.l2 = snt.LSTM(100)
      self.out = snt.LSTM(10)

  def initial_state(self, batch_size):
    return (self.l1.initial_state(batch_size),
            self.l2.initial_state(batch_size),
            self.out.initial_state(batch_size))

  def _build(self, inputs, prev_state):

    # separate the components of prev_state
    l1_prev_state, l2_prev_state, out_prev_state = prev_state

    l1_out, l1_next_state = self.l1(inputs, l1_prev_state)
    l1_out = tf.tanh(l1_out)
    l2_out, l2_next_state = self.l2(l1_out, l2_prev_state)
    l2_out = tf.tanh(l2_out)
    output, out_next_state = self.out(l2_out, out_prev_state)

    # Output state of LSTMNetwork contains the output states of inner modules.
    full_output_state = (l1_next_state, l2_next_state, out_next_state)

    return tf.sigmoid(output), full_output_state
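
This hand-written core is then used in the same way as the DeepRNN above, e.g. for a single timestep (a quick sketch reusing batch_size and single_timestep_input from the earlier snippet):

core = LSTMNetwork()
state = core.initial_state(batch_size=batch_size)
output, state = core(single_timestep_input, state)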

Finally, if you're using eager mode I would strongly encourage you to have a look at Sonnet 2 - it's a complete rewrite for TF 2 / Eager mode. It's not backwards compatible, but all the same kinds of module compositions are possible. Sonnet 1 was written primarily for Graph mode TF, and while it does work with Eager mode you'll probably encounter some things that aren't very convenient.
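
As a rough sketch (API names taken from the v2 repo at the time of writing, so please double-check them against the current docs), the same composition in Sonnet 2 / TF 2 eager mode looks something like this:

# Sonnet 2 / TF 2 sketch, assuming the v2 package is installed as `sonnet`.
import sonnet as snt
import tensorflow as tf

model = snt.DeepRNN([
    snt.LSTM(100),
    tf.tanh,
    snt.LSTM(100),
    tf.tanh,
    snt.LSTM(10),
    tf.sigmoid,
])

batch_size, sequence_length, input_size = 2, 3, 4
sequence_input = tf.random.uniform([batch_size, sequence_length, input_size])

# Unroll manually in eager mode, one timestep at a time.
state = model.initial_state(batch_size)
outputs = []
for t in range(sequence_length):
    output, state = model(sequence_input[:, t, :], state)
    outputs.append(output)
sequence_output = tf.stack(outputs, axis=1)  # [batch_size, sequence_length, 10]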

We worked closely with the TensorFlow team to make sure that TF 2 & Sonnet 2 work nicely together, so please have a look: https://github.com/deepmind/sonnet/tree/v2. Sonnet 2 should be considered alpha and is being actively developed, so we don't have loads of examples yet, but more will be added in the near future.