Понимание того, как работает реализация TF для CTC - PullRequest
0 голосов
/ 27 сентября 2018

Я пытаюсь понять, как работает реализация CTC в TensorFlow.Я написал быстрый пример просто для проверки функции CTC, но по какой-то причине я получаю inf для некоторых целевых / входных значений, и я уверен, почему это происходит!?

Код:

import tensorflow as tf
import numpy as np

# https://github.com/philipperemy/tensorflow-ctc-speech-recognition/blob/master/utils.py
def sparse_tuple_from(sequences, dtype=np.int32):
    """Create a sparse representention of x.
    Args:
        sequences: a list of lists of type dtype where each element is a sequence
    Returns:
        A tuple with (indices, values, shape)
    """
    indices = []
    values = []

    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)

    return indices, values, shape

batch_size = 1
seq_length = 2
n_labels = 2

seq_len = tf.placeholder(tf.int32, [None])
targets = tf.sparse_placeholder(tf.int32)
logits = tf.constant(np.random.random((batch_size, seq_length, n_labels+1)),dtype=tf.float32) # +1 for the blank label
loss = tf.reduce_mean(tf.nn.ctc_loss(targets, logits, seq_len, time_major = False))


with tf.Session() as sess:
    for it in range(10):
        rand_target = np.random.randint(n_labels, size=(seq_length))
        sample_target = sparse_tuple_from([rand_target])

        logitsval = sess.run(logits)
        lossval = sess.run(loss, feed_dict={seq_len: [seq_length], targets: sample_target})
        print('******* Iter: %d *******'%it)
        print('logits:', logitsval)
        print('rand_target:', rand_target)
        print('rand_sparse_target:', sample_target)
        print('loss:', lossval)
        print()

Пример вывода:

******* Iter: 0 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 1 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 1], dtype=int32), array([1, 2]))
loss: inf

******* Iter: 2 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 3 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 0], dtype=int32), array([1, 2]))
loss: 1.59766

******* Iter: 4 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 0], dtype=int32), array([1, 2]))
loss: inf

******* Iter: 5 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 6 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 0], dtype=int32), array([1, 2]))
loss: 1.59766

******* Iter: 7 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 1], dtype=int32), array([1, 2]))
loss: inf

******* Iter: 8 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 9 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 0], dtype=int32), array([1, 2]))
loss: inf

Есть идеи, что мне там не хватает!?

1 Ответ

0 голосов
/ 30 сентября 2018

Посмотрите внимательно на ваши входные тексты (rand_target), я уверен, что вы видите какой-то простой шаблон, который коррелирует со значением inf inf; -)

Краткое объяснение того, что происходит: CTC кодирует текст с помощьюпозволяет повторять каждый символ, а также позволяет вставлять не символьный маркер (называемый «пустая метка CTC») между символами.Отмена этого кодирования (или декодирования) просто означает отбрасывание повторяющихся символов, а затем отбрасывание всех пробелов.Чтобы привести несколько примеров («...» соответствует тексту, «...» - кодировкам, а «-» - пустой метке):

  • «to» -> «tttooo» или'to' или 't-oo', или 'to', и так далее ...
  • "too" -> 'to-o', или 'tttoo --- oo', или '---too-- ', но НЕ «слишком» (подумайте о том, как будет выглядеть декодированный «слишком»)

Теперь мы знаем достаточно, чтобы понять, почему некоторые из ваших сэмплов терпят неудачу:

  • длина вашего входного текста составляет 2
  • длина кодировки составляет 2
  • , если вводимый символ повторяется (например, '11' или в виде списка Python: [1, 1]), тогда единственный способ закодировать это - поместить пробел между ними (подумайте о том, как расшифровывать «11» и «1-1»).Но тогда кодирование будет иметь длину 3.
  • , поэтому невозможно кодировать тексты длины 2 с повторяющимся символом в кодировку длины 2, поэтому реализация потери TF возвращает inf

Вы также можете представить кодировку как конечный автомат - см. Иллюстрацию ниже.Текст «11» может быть представлен всеми возможными путями, начиная с начального состояния (два крайних левых состояния) и заканчивая конечным состоянием (два крайних правых состояния).Как видите, кратчайший путь - «1-1».

enter image description here

В заключение вам необходимо учесть как минимум один дополнительный пробелбыть вставленным для каждого повторяющегося символа во входном тексте.Может быть, эта статья поможет понять CTC: https://towardsdatascience.com/3797e43a86c

...