LSTM Followed by Mean Pooling (TensorFlow)

I am aware that there is a similar topic at LSTM Followed by Mean Pooling, but that is about Keras and I work in pure TensorFlow.

I have an LSTM network where the recurrence is handled by:

outputs, final_state = tf.nn.dynamic_rnn(cell,
                                         embed,
                                         sequence_length=seq_lengths,
                                         initial_state=initial_state)

where I pass the correct sequence length for each sample (the sequences are zero-padded). Because the sequence lengths differ across samples, outputs still contains irrelevant entries at the padded timesteps.

Right now I'm extracting the last relevant output by means of the following method:

def extract_axis_1(data, ind):
    """
    Get specified elements along the first axis of tensor.
    :param data: TensorFlow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """

    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    return tf.gather_nd(data, indices)

where I pass sequence_length - 1 as the indices. Following the linked topic, I would like to select all relevant outputs and apply average pooling over them, instead of taking just the last one.
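For reference, the helper is invoked like this (a minimal sketch; last_outputs is an illustrative name):

# Index of the last valid timestep for each sample; cast to int32 to
# match the dtype of tf.range inside the helper.
last_outputs = extract_axis_1(outputs, tf.cast(seq_lengths - 1, tf.int32))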

Now, I tried passing nested lists as indices to extract_axis_1, but tf.stack does not accept this.

Any solution directions for this?

1 Answer

Giuseppe Marra answered:

You can exploit the weights parameter of the tf.contrib.seq2seq.sequence_loss function.

From the documentation:

weights: A Tensor of shape [batch_size, sequence_length] and dtype float. weights constitutes the weighting of each prediction in the sequence. When using weights as masking, set all valid timesteps to 1 and all padded timesteps to 0, e.g. a mask returned by tf.sequence_mask.

You need to compute a binary mask that distinguishes your valid outputs from the invalid ones. Then you can just provide this mask to the weights parameter of the loss function (you will probably want to use a sequence loss of this kind); the function will not consider outputs with a weight of 0 when computing the loss.
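A minimal sketch, assuming per-timestep logits of shape [batch_size, max_time, num_classes] and integer targets of shape [batch_size, max_time] (logits and targets are illustrative names, not from the question):

import tensorflow as tf

# Binary mask: 1.0 at valid timesteps, 0.0 at padded ones.
mask = tf.sequence_mask(seq_lengths, maxlen=tf.shape(logits)[1], dtype=tf.float32)

# Timesteps with weight 0 are ignored in the loss computation.
loss = tf.contrib.seq2seq.sequence_loss(logits=logits,
                                        targets=targets,
                                        weights=mask)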

If you can't or don't want to use a sequence loss, you can do exactly the same thing manually: compute a binary mask, multiply your outputs by it, and feed the result to your fully connected layer.
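Since the question asks for mean pooling over the valid outputs, here is a minimal sketch of that manual approach (reusing outputs and seq_lengths from the question; masked_sum and mean_pooled are illustrative names):

import tensorflow as tf

# Mask of shape [batch_size, max_time], expanded so it broadcasts over
# the hidden dimension of outputs ([batch_size, max_time, num_units]).
mask = tf.sequence_mask(seq_lengths, maxlen=tf.shape(outputs)[1], dtype=tf.float32)
mask = tf.expand_dims(mask, axis=-1)

# Zero out the padded timesteps, then average over the valid ones only.
masked_sum = tf.reduce_sum(outputs * mask, axis=1)
mean_pooled = masked_sum / tf.cast(tf.expand_dims(seq_lengths, -1), tf.float32)

mean_pooled then has shape [batch_size, num_units] and can be fed to the fully connected layer.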