I'm trying to understand how the CTC implementation works in TensorFlow. I wrote a quick example just to test the CTC function, but for some reason I'm getting inf for some target/input values, and I'm not sure why that is happening!?
Code:
import tensorflow as tf
import numpy as np

# https://github.com/philipperemy/tensorflow-ctc-speech-recognition/blob/master/utils.py
def sparse_tuple_from(sequences, dtype=np.int32):
    """Create a sparse representation of x.
    Args:
        sequences: a list of lists of type dtype where each element is a sequence
    Returns:
        A tuple with (indices, values, shape)
    """
    indices = []
    values = []
    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)
    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), indices.max(0)[1] + 1], dtype=np.int64)
    return indices, values, shape

batch_size = 1
seq_length = 2
n_labels = 2

seq_len = tf.placeholder(tf.int32, [None])
targets = tf.sparse_placeholder(tf.int32)
logits = tf.constant(np.random.random((batch_size, seq_length, n_labels + 1)), dtype=tf.float32)  # +1 for the blank label
loss = tf.reduce_mean(tf.nn.ctc_loss(targets, logits, seq_len, time_major=False))

with tf.Session() as sess:
    for it in range(10):
        rand_target = np.random.randint(n_labels, size=(seq_length))
        sample_target = sparse_tuple_from([rand_target])
        logitsval = sess.run(logits)
        lossval = sess.run(loss, feed_dict={seq_len: [seq_length], targets: sample_target})

        print('******* Iter: %d *******' % it)
        print('logits:', logitsval)
        print('rand_target:', rand_target)
        print('rand_sparse_target:', sample_target)
        print('loss:', lossval)
        print()
Sample Output:
******* Iter: 0 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521
******* Iter: 1 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [1 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([1, 1], dtype=int32), array([1, 2]))
loss: inf
******* Iter: 2 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521
******* Iter: 3 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [1 0]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([1, 0], dtype=int32), array([1, 2]))
loss: 1.59766
******* Iter: 4 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 0]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 0], dtype=int32), array([1, 2]))
loss: inf
******* Iter: 5 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521
******* Iter: 6 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [1 0]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([1, 0], dtype=int32), array([1, 2]))
loss: 1.59766
******* Iter: 7 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [1 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([1, 1], dtype=int32), array([1, 2]))
loss: inf
******* Iter: 8 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521
******* Iter: 9 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 0]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 0], dtype=int32), array([1, 2]))
loss: inf
Any idea what I am missing here!?
Look closely at your input texts (rand_target); I'm sure you'll see a simple pattern which correlates with the inf loss values ;-)
A short explanation of what is happening: CTC encodes text by allowing each character to be repeated, and it also allows a non-character marker (called the "CTC blank label") to be inserted between characters. Undoing this encoding (or decoding) then simply means throwing away repeated characters and then throwing away all blanks. To give some examples ("..." corresponds to text, '...' to encodings and '-' to the blank label):

"01" can be encoded as '01', '0-1', '-01', '001', '011', ...
"0" can be encoded as '0', '00', '0-', '-0', ...
"11" can only be encoded as '1-1', '1-11', '11-1', '1--1', ...; the plain encoding '11' decodes to "1", so the two 1s must be kept apart by at least one blank.
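The decoding rule is simple enough to write down directly. Here is a minimal sketch of it in plain Python (the function name ctc_decode and the use of '-' as the blank symbol are just for illustration, not TensorFlow API):

def ctc_decode(encoding, blank='-'):
    # step 1: collapse runs of repeated characters
    collapsed = [c for i, c in enumerate(encoding) if i == 0 or c != encoding[i - 1]]
    # step 2: throw away all blanks
    return ''.join(c for c in collapsed if c != blank)

print(ctc_decode('0-1'))  # "01"
print(ctc_decode('11'))   # "1"  - repeated 1s collapse into a single 1
print(ctc_decode('1-1'))  # "11" - the blank keeps the two 1s apart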
Now we know enough to see why some of your samples fail: the failing targets are exactly "00" and "11". Their shortest possible encodings are '0-0' and '1-1', which are 3 characters long, but your logits only provide seq_length = 2 time-steps. No valid alignment between target and input exists, so the probability of the target is 0 and the loss is -log(0) = inf. The targets "01" and "10" contain no repeated characters, so encodings of length 2 exist ('01' and '10') and the loss stays finite.
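You can make this length check explicit. A small sketch (the helper min_ctc_len is hypothetical, not part of TensorFlow):

def min_ctc_len(target):
    # one time-step per label, plus one blank between each
    # pair of adjacent repeated labels
    repeats = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return len(target) + repeats

seq_length = 2
for t in ([0, 1], [1, 1], [1, 0], [0, 0]):
    verdict = 'inf' if min_ctc_len(t) > seq_length else 'finite'
    print(t, '-> min encoding length:', min_ctc_len(t), '-> loss:', verdict)

This reproduces your sample output: [1, 1] and [0, 0] need 3 time-steps but only get 2, hence inf.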
You can also imagine the encoding as a state machine in which the text "11" is represented by all possible paths leading from a start state to a final state. As the enumeration below shows, the shortest possible path is '1-1', which already takes three time-steps.
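To convince yourself, you can brute-force all paths of a given length. A minimal sketch (count_alignments is an illustrative helper; blank index 2 matches the n_labels + 1 classes in your code, since tf.nn.ctc_loss reserves the last class index for the blank):

import itertools

def collapse(path, blank):
    # remove repeated labels, then remove blanks
    deduped = [s for i, s in enumerate(path) if i == 0 or s != path[i - 1]]
    return [s for s in deduped if s != blank]

def count_alignments(target, T, n_classes, blank):
    # count label paths of length T that decode to the target
    return sum(1 for path in itertools.product(range(n_classes), repeat=T)
               if collapse(path, blank) == target)

# target "11" with 2 time-steps: no valid path, hence loss = inf
print(count_alignments([1, 1], T=2, n_classes=3, blank=2))  # 0
# with 3 time-steps exactly one path exists: '1-1', i.e. [1, 2, 1]
print(count_alignments([1, 1], T=3, n_classes=3, blank=2))  # 1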
To conclude: you have to account for at least one additional blank to be inserted for each repeated character in the target text, i.e. the input must provide at least len(text) plus the number of adjacent repeats many time-steps. Maybe this article helps in understanding CTC: https://towardsdatascience.com/3797e43a86c
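Applied to your toy example, one possible fix (just a sketch, assuming you want to keep drawing random targets) is to resample any target whose shortest encoding does not fit into the available time-steps:

import numpy as np

n_labels = 2
seq_length = 2

def fits(target, T):
    # shortest CTC encoding: len(target) plus one blank per adjacent repeat
    repeats = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return len(target) + repeats <= T

rand_target = np.random.randint(n_labels, size=seq_length)
while not fits(rand_target, seq_length):
    rand_target = np.random.randint(n_labels, size=seq_length)
# rand_target now always admits a valid alignment, so the loss stays finite

Alternatively, make the time dimension of your logits at least 2 * target_length - 1, which covers the worst case of a target consisting of a single repeated character.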