How to improve word RNN accuracy in TensorFlow?


I'm working on a title auto-generation project with TensorFlow's seq2seq.rnn_decoder.

My training set is a large collection of titles; each title is independent of the others and unrelated to the rest.

I have tried two data formats for training:

F1. Use a fixed sequence length in each batch and replace '\n' with '<eos>' (the '<eos>' index is 1), so a training batch looks like: [2,3,4,5,8,9,1,2,3,4], [88,99,11,90,1,5,6,7,8,10]
F2. Use the variable sequence length of each title and pad with 0 up to the fixed length, so a training batch looks like: [2,3,4,5,8,9,0,0,0,0], [2,3,4,88,99,90,11,0,0,0]
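
To make the two formats concrete, here is a minimal plain-Python sketch of how they could be built. This is an illustration rather than my actual preprocessing code, and it assumes the titles are already lists of word ids, with '<eos>' = 1 and PAD = 0 as above:

    # Hypothetical helper code: build F1- and F2-style training sequences
    # from titles that are already converted to word ids.
    EOS = 1        # '<eos>' id, as in the question
    PAD = 0        # padding id, as in the question
    SEQ_LEN = 10   # fixed sequence length

    def make_f1_sequences(titles, seq_len=SEQ_LEN):
        """F1: concatenate all titles into one stream separated by <eos>,
        then cut the stream into fixed-length chunks."""
        stream = []
        for title in titles:
            stream.extend(title + [EOS])
        return [stream[i:i + seq_len]
                for i in range(0, len(stream) - seq_len + 1, seq_len)]

    def make_f2_sequences(titles, seq_len=SEQ_LEN):
        """F2: one title per example, truncated or right-padded with PAD."""
        return [(title + [PAD] * seq_len)[:seq_len] for title in titles]

    titles = [[2, 3, 4, 5, 8, 9], [88, 99, 11, 90]]
    print(make_f1_sequences(titles))  # [[2, 3, 4, 5, 8, 9, 1, 88, 99, 11]]
    print(make_f2_sequences(titles))  # [[2, 3, 4, 5, 8, 9, 0, 0, 0, 0],
                                      #  [88, 99, 11, 90, 0, 0, 0, 0, 0, 0]]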

Then I ran a test on a small set of 10,000 titles, but the results confused me.

F1 gives good single-word predictions, like this:

input "iphone"  -> predicts "6"
input "samsung" -> predicts "galaxy"
input "case"    -> predicts "cover"

F2 gives good predictions for a long sentence if the input starts from the first word of the sentence; often the prediction is almost identical to the original sentence.

But if the starting word is taken from the middle (or near the end) of the sentence, F2's prediction is very bad, almost like a random result.

Is this related to the hidden state?

In the training phase, I reset the hidden state to 0 only when a new epoch begins, so all batches within an epoch carry the same hidden state forward. I suspect this is not good practice, because the sentences are actually independent of each other. Should they share the same hidden state during training?

In the inference phase, the initial hidden state is 0 and is updated each time a word is fed (it is reset to 0 when the input is cleared).
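
To make the state handling in both phases concrete, here is a rough, self-contained sketch of the pattern rather than my actual project code. It uses TF 1.x tf.nn.dynamic_rnn with an LSTM cell instead of seq2seq.rnn_decoder to keep it short, and get_batches / the seed word are hypothetical toy placeholders:

    import numpy as np
    import tensorflow as tf

    VOCAB, HIDDEN = 5000, 128

    inputs  = tf.placeholder(tf.int32, [None, None])   # [batch, time]
    targets = tf.placeholder(tf.int32, [None, None])

    embedding = tf.get_variable("embedding", [VOCAB, HIDDEN])
    embedded  = tf.nn.embedding_lookup(embedding, inputs)

    cell = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN)
    init_state = cell.zero_state(tf.shape(inputs)[0], tf.float32)
    outputs, final_state = tf.nn.dynamic_rnn(cell, embedded, initial_state=init_state)

    logits = tf.layers.dense(outputs, VOCAB)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=logits))
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
    next_word = tf.argmax(logits[:, -1, :], axis=1)

    def zero_state_value(n):
        # Numpy zeros matching the LSTMStateTuple structure, used to feed init_state.
        z = np.zeros([n, HIDDEN], np.float32)
        return tf.nn.rnn_cell.LSTMStateTuple(z, z)

    def get_batches(num_batches=5, batch=4, seq_len=10):
        # Hypothetical toy batches of random word ids, just so the sketch runs.
        for _ in range(num_batches):
            x = np.random.randint(2, VOCAB, size=(batch, seq_len))
            yield x[:, :-1], x[:, 1:]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Training: since the titles are independent, reset the state for EVERY
        # batch instead of only once per epoch (my current setup carries the
        # final state of one batch into the next, unrelated batch).
        for epoch in range(10):
            for x, y in get_batches():
                sess.run(train_op, {inputs: x, targets: y,
                                    init_state: zero_state_value(len(x))})

        # Inference: start from the zero state, then carry the returned state
        # forward as each word is fed; reset to zeros when the input is cleared.
        state = zero_state_value(1)
        word = 2                                       # hypothetical seed word id
        for _ in range(20):
            word, state = sess.run([next_word, final_state],
                                   {inputs: [[word]], init_state: state})
            word = int(word[0])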

So my questions are: why is F2's prediction bad when the starting word comes from the middle (or near the end) of the sentence? And what is the right way to update the hidden state in my project?


1 Answer

Answer by Lukasz Kaiser

I'm not sure I understand your setting 100% correctly, but I think what you see happening is expected and has to do with the handling of the hidden state.

Let's first look at what you see in F2. Since you reset your hidden state every time, the network only sees a 0-state at the beginning of a whole title, right? So during training it probably never encounters a 0-state except when starting a sequence. When you try to decode from the middle, you start from a 0-state at a position it has never seen during training, so it fails.

In F1, you also reset the state, but since you're not padding, the 0-state appears more randomly during training -- sometimes at the beginning, sometimes in the middle of the title. And the network learns to cope with this.
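
One way to see this concretely (a toy illustration with made-up word ids, not code from the question) is to check where the training-sequence boundaries, i.e. the positions where a fresh initial state is used, fall relative to the title boundaries in each format:

    # Toy illustration: where do training-sequence boundaries fall relative
    # to title starts in F1 vs F2? (word ids are made up)
    EOS, PAD, SEQ_LEN = 1, 0, 10
    titles = [[2, 3, 4, 5, 8, 9], [88, 99, 11, 90], [5, 6, 7, 8, 10]]

    # F1: one long stream cut into fixed-length chunks.
    stream = [w for t in titles for w in t + [EOS]]
    f1_chunk_starts_mid_title = [
        i > 0 and stream[i - 1] != EOS        # chunk begins inside a title
        for i in range(0, len(stream), SEQ_LEN)
    ]

    # F2: one padded example per title, so every example starts at a title start.
    f2_example_starts_mid_title = [False for _ in titles]

    print(f1_chunk_starts_mid_title)    # [False, True]  (some chunks start mid-title)
    print(f2_example_starts_mid_title)  # [False, False, False]

So with F2 the network only ever learns to continue from a fresh state at the very start of a title, which matches the behavior described in the question.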