Predicting the Sine Function with RNNs

As a toy problem, I am trying to generate sequences from a given function, e.g. the sine. For that I am using LSTMs in TensorFlow. I think I understood the first LSTM tutorial on the official website, but I did not use the second one (seq2seq), partly because I am not sure it fits this problem and partly because I want to solve it by hand first. I also looked at some other posts, such as ones on generating sequences of characters.

For training, I feed the network sequences of length n sampled from the sine function; the target is the same sequence shifted by one step, so each target value is the next value of the input. I compute the L2 loss between the network outputs and these targets and use it as the training loss.
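A minimal sketch of this data setup in plain Python (no TensorFlow), assuming the sine is sampled at a fixed step of 0.1 and that "L2 loss" means half the sum of squared errors, which is what tf.nn.l2_loss computes; a mean squared error would behave the same way up to scaling. The helper names make_windows and l2_loss are hypothetical.

```python
import math

def make_windows(values, n):
    """Split a 1-D series into (input, target) pairs of length n,
    where the target window starts one sample later than the input."""
    pairs = []
    for start in range(len(values) - n):
        x = values[start:start + n]
        y = values[start + 1:start + n + 1]  # each y[i] is the value after x[i]
        pairs.append((x, y))
    return pairs

def l2_loss(pred, target):
    """Half the sum of squared errors over one sequence (like tf.nn.l2_loss)."""
    return 0.5 * sum((p - t) ** 2 for p, t in zip(pred, target))

# Sample the sine at a fixed step; 0.1 is an arbitrary choice for this sketch.
series = [math.sin(0.1 * i) for i in range(200)]
pairs = make_windows(series, n=20)
x0, y0 = pairs[0]
```

Each target window overlaps its input window in all but one position, so the network only ever has to predict one sample ahead.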

inputs = tf.unpack(self.input_data, num=num_steps, axis=1) 
outputs, state = tf.nn.rnn(self.cell, inputs, initial_state=self.initial_state)
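To make the shapes in these two lines concrete, here is a plain-Python sketch of what unpacking along axis 1 and a static unroll do with the data. It swaps the LSTM for a hypothetical scalar tanh cell (toy_cell) and uses nested lists instead of tensors, so it mirrors only the data flow, not TensorFlow's implementation; unstack_axis1 and static_unroll are illustrative names.

```python
import math

def unstack_axis1(batch):
    """Turn a [batch][num_steps] nested list into a list of num_steps
    columns, each holding one value per batch element (axis-1 unpacking)."""
    num_steps = len(batch[0])
    return [[row[t] for row in batch] for t in range(num_steps)]

def toy_cell(x, h, w_x=0.5, w_h=0.9, b=0.0):
    """Hypothetical scalar tanh RNN cell standing in for the LSTM cell.
    Returns (output, new_state); here both are the same value."""
    new_h = math.tanh(w_x * x + w_h * h + b)
    return new_h, new_h

def static_unroll(cell, inputs, initial_state):
    """Apply the cell step by step, threading the state through,
    and collect one output list (one value per batch element) per step."""
    outputs, state = [], list(initial_state)
    for x_t in inputs:                # one entry per time step
        step_out = []
        for i, x in enumerate(x_t):   # one entry per batch element
            out, state[i] = cell(x, state[i])
            step_out.append(out)
        outputs.append(step_out)
    return outputs, state

batch = [[0.0, 0.1, 0.2], [1.0, 1.1, 1.2]]  # 2 sequences, 3 steps each
inputs = unstack_axis1(batch)
outputs, final_state = static_unroll(toy_cell, inputs, [0.0, 0.0])
```

The final state is whatever the cell produced at the last step; it is the value one would carry into a later call if the sequence is meant to continue.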

On each of the outputs I apply a fully connected layer to get one output value per example in the batch:

self.output = (tf.add(tf.matmul(output, output_w1),output_b1))
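A plain-Python sketch of this readout, assuming each step's output is a hidden vector and that output_w1 has a single column, so the layer maps every hidden vector to one number. The same weights are reused at every time step, which yields one prediction per batch element per step. The names dense and readout are hypothetical helpers, not TensorFlow functions.

```python
def dense(h, w, b):
    """One output value from a hidden vector h: dot(h, w) + b,
    i.e. one row of matmul(output, output_w1) + output_b1."""
    return sum(hi * wi for hi, wi in zip(h, w)) + b

def readout(step_outputs, w, b):
    """Apply the same dense layer to every time step and batch element.
    step_outputs is [num_steps][batch][hidden]; the result is [num_steps][batch]."""
    return [[dense(h, w, b) for h in batch_at_t] for batch_at_t in step_outputs]

# 2 time steps, batch of 2, hidden size 2 (made-up numbers for illustration)
step_outputs = [[[1.0, 2.0], [0.0, 1.0]],
                [[3.0, 0.0], [1.0, 1.0]]]
preds = readout(step_outputs, w=[0.5, -1.0], b=0.25)
```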

The training works well and the error approaches 0. Then I want to recreate the sine function with the learned parameters. For this I switch the sequence length to 1 and feed in one value at a time. I initialize the network by feeding in a few values from the training data, and then use the output of each step as the next input. However, this does not work at all: right after the initialization, the output starts to deviate from the sine function and converges to a fixed value. Interestingly, when I plot the network output while feeding in the training data, the sine is reconstructed accurately. I can even feed in the generated output on every second step and still get the sine function; it just seems to need this "stabilization".
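The generation procedure described here (warm up on real values, then feed each output back in) can be sketched as a loop over a one-step model. The interface step(x, state) -> (prediction, new_state) is a hypothetical stand-in for one run of the network with sequence length 1; it is not the trained LSTM. So that the loop can be checked, the sketch plugs in an exact predictor for a sampled sine, based on the identity sin(t + d) = 2 cos(d) sin(t) - sin(t - d), with the previous sample as its "state". The step size DT = 0.1 is an arbitrary assumption.

```python
import math

def generate(step, state, prime, n_steps):
    """Warm up on real values, then feed each prediction back as the next input.

    step(x, state) -> (prediction_of_next_value, new_state) is a hypothetical
    one-step model interface; the state returned by each call is passed on.
    """
    if not prime:
        raise ValueError("need at least one warm-up value")
    pred = None
    for x in prime:                # warm-up: ground-truth inputs
        pred, state = step(x, state)
    generated = []
    x = pred                       # first free-running input is the last prediction
    for _ in range(n_steps):       # closed loop: own output becomes the next input
        generated.append(x)
        x, state = step(x, state)
    return generated

DT = 0.1  # sampling step of the sine (arbitrary)

def sine_step(x, prev):
    """Exact one-step predictor for sin(DT * t); the state is the previous sample."""
    return 2.0 * math.cos(DT) * x - prev, x

prime = [math.sin(DT * t) for t in range(10)]
out = generate(sine_step, math.sin(-DT), prime, 50)
```

Because sine_step is exact, out[i] tracks sin((10 + i) * DT). Carrying the state from one single-step call to the next is what lets each call continue the sequence rather than restart it.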

My first question: is this procedure correct at all, i.e. training with this sequence loss and then, for inference, feeding in one value at a time? Even if it worked, it is very slow; is there a way to improve this? And of course: why does it perform so badly, and what am I doing wrong?

Any help would be greatly appreciated! Thanks :)
