Понимание K.ctc_decode - PullRequest
       38

Понимание K.ctc_decode

1 голос
/ 20 апреля 2020

Здесь - это тест Кераса для ctc_decode.

Вот мой слегка измененный пример:

def test_ctc_decode_greedy():
    def _remove_repeats(inds):
        is_not_repeat = np.insert(np.diff(inds).astype(np.bool), 0, True)
        return inds[is_not_repeat]

    def _remove_blanks(inds, n_classes):
        return inds[inds < (n_classes - 1)]

    def ctc_decode_np(y_pred, input_length):
        # Note:
        # Last element in alphabet treated as blank character
        # decoded_dense padded with -1

        n_samples = y_pred.shape[0]
        n_classes = y_pred.shape[-1]
        log_prob = np.zeros((n_samples, 1))
        decoded_dense = -np.ones_like(y_pred[..., 0])
        decoded_length = np.zeros((n_samples,), dtype=np.int)

        for i in range(n_samples):
            print('-'*60)
            # [n_time_steps, alphabet_size]
            prob = y_pred[i]

            length = input_length[i]

            decoded = np.argmax(prob[:length], axis=-1)

            print('decoded:', decoded)

            log_prob[i] = -np.sum(np.log(prob[np.arange(length), decoded]))

            decoded = _remove_repeats(decoded)

            print('decoded remove_repeats:', decoded)

            decoded = _remove_blanks(decoded, n_classes)

            print('decoded remove_blanks:', decoded)

            decoded_length[i] = len(decoded)
            decoded_dense[i, :len(decoded)] = decoded

        print('-' * 60)

        return decoded_dense[:, :np.max(decoded_length)], log_prob

    n_time_steps = 6
    alphabet_size = 4

    seq_len_0 = 4
    input_prob_matrix_0 = np.asarray(
        [[1.0, 0.0, 0.0, 0.0],  # t=0
         [0.0, 0.0, 0.4, 0.6],  # t=1
         [0.0, 0.0, 0.4, 0.6],  # t=2
         [0.0, 0.9, 0.1, 0.0],  # t=3
         [0.0, 0.0, 0.0, 0.0],  # t=4 (ignored)
         [0.0, 0.0, 0.0, 0.0]], # t=5 (ignored)
        dtype=np.float32)

    seq_len_1 = 5
    input_prob_matrix_1 = np.asarray(
        [[0.1, 0.9, 0.0, 0.0],  # t=0
         [0.0, 0.9, 0.1, 0.0],  # t=1
         [0.0, 0.0, 0.1, 0.9],  # t=2
         [0.0, 0.9, 0.1, 0.1],  # t=3
         [0.9, 0.1, 0.0, 0.0],  # t=4
         [0.0, 0.0, 0.0, 0.0]], # t=5 (ignored)
        dtype=np.float32)

    # [batch_size, max_chars_in_text, alphabet_size]
    inputs = np.array([input_prob_matrix_0, input_prob_matrix_1])
    print('inputs.shape', inputs.shape)

    # [batch_size, ]
    input_length = np.array([seq_len_0, seq_len_1], dtype=np.int32)
    print('input_length.shape', input_length.shape)

    decode_pred_np, log_prob_pred_np = ctc_decode_np(inputs, input_length)

    # max_decoded_text_lenght depends on batch, other shorter samples in batch padded with -1
    # [batch_size, max_decoded_text_lenght]
    print('decode_pred_np.shape', decode_pred_np.shape)
    print(decode_pred_np)

    # [batch_size, 1]
    print('log_prob_pred_np.shape', log_prob_pred_np.shape)
    print(log_prob_pred_np)

    inputs = K.variable(inputs)
    input_length = K.variable(input_length)
    decode_pred_tf, log_prob_pred_tf = K.ctc_decode(inputs, input_length, greedy=True)
    assert len(decode_pred_tf) == 1
    decode_pred = K.eval(decode_pred_tf[0])
    log_prob_pred = K.eval(log_prob_pred_tf)

    assert np.alltrue(decode_pred_np == decode_pred)
    assert np.allclose(log_prob_pred_np, log_prob_pred)

Вывод:

------------------------------------------------------------
keras version: 2.1.6
tensorflow version: 1.14.0
------------------------------------------------------------
inputs.shape (2, 6, 4)
input_length.shape (2,)
------------------------------------------------------------
decoded: [0 3 3 1]
decoded remove_repeats: [0 3 1]
decoded remove_blanks: [0 1]
------------------------------------------------------------
decoded: [1 1 3 1 0]
decoded remove_repeats: [1 3 1 0]
decoded remove_blanks: [1 1 0]
------------------------------------------------------------
decode_pred_np.shape (2, 3)
[[ 0.  1. -1.]
 [ 1.  1.  0.]]
log_prob_pred_np.shape (2, 1)
[[1.12701166]
 [0.52680272]]

Вопросы:

  1. Во время вывода, как мы получаем количество действительных временных шагов предсказания (например, некоторые последние строки игнорируются)? Как я понимаю при выводе, у нас есть матрица фиксированного размера в качестве выходных данных из сети, нужен ли нам другой выход для прогнозирования количества действительных временных шагов?

  2. Как я понимаю, log_prob отрицательно логарифмическая вероятность (чем меньше значение, тем лучше -> сеть более уверенно выводит), но поскольку количество временных шагов может быть разным, значения log_prob не сравнимы между выборками с разной длительностью временных шагов (потому что я не вижу какая-нибудь нормализация по длине)?

ОБНОВЛЕНИЕ:

Чтобы добавить больше контекста, вот мой более читаемый пример с реальным алфавитом и кодированием / декодированием текста здесь я использую input_length = n_time_steps:

def test_ctc_decode_greedy():
    def labels_to_text(labels, alphabet):
        chars = []
        for i in labels:
            chars.append(alphabet[i])
        text = ''.join(chars)
        return text

    def _remove_repeats(inds):
        is_not_repeat = np.insert(np.diff(inds).astype(np.bool), 0, True)
        return inds[is_not_repeat]

    def _remove_blanks(inds, n_classes):
        return inds[inds < (n_classes - 1)]

    def ctc_decode_np(y_pred, input_length):
        # Note:
        # Last element in alphabet treated as blank character
        # decoded_dense padded with -1

        n_samples = y_pred.shape[0]
        n_classes = y_pred.shape[-1] # len(alphabet) + 1
        log_prob = np.zeros((n_samples, 1))
        decoded_dense = -np.ones_like(y_pred[..., 0])

        print('n_samples', n_samples) #
        print('n_classes', n_classes) #
        print('log_prob.shape', log_prob.shape) #
        print('decoded_dense.shape', decoded_dense.shape) #

        for i in range(n_samples):
            print('-'*60)
            # [n_time_steps, alphabet_size]
            prob = y_pred[i]
            length = input_length[i]

            decoded = np.argmax(prob[:length], axis=-1)

            print('decoded:', decoded, labels_to_text(decoded, alphabet)) #

            print('prob[np.arange(length), decoded] :', prob[np.arange(length), decoded]) #
            print('np.log(prob[np.arange(length), decoded]) :', np.log(prob[np.arange(length), decoded])) #

            log_prob[i] = -np.sum(np.log(prob[np.arange(length), decoded]))
            print('log_prob[i]', log_prob[i]) #

            decoded = _remove_repeats(decoded)

            print('decoded: remove_repeats:', decoded, labels_to_text(decoded, alphabet))  #

            decoded = _remove_blanks(decoded, n_classes)

            print('decoded: remove_blanks:', decoded, labels_to_text(decoded, alphabet))  #

            decoded_dense[i, :len(decoded)] = decoded

        print('-' * 60)

        return decoded_dense, log_prob

    n_time_steps = 8
    alphabet = ['h', 'e', 'l', 'o', '-']
    alphabet_size = len(alphabet)

    seq_len_0 = n_time_steps
    # 'hhel-lo-' after decoding should be 'hello'
    input_prob_matrix_0 = np.asarray(
        [[1.0, 0.0, 0.0, 0.0, 0.0],  # t=0 # 'h'
         [1.0, 0.0, 0.0, 0.0, 0.0],  # t=1 # 'h'
         [0.0, 1.0, 0.0, 0.0, 0.0],  # t=2 # 'e'
         [0.0, 0.0, 1.0, 0.0, 0.0],  # t=3 # 'l'
         [0.0, 0.0, 0.0, 0.0, 1.0],  # t=4 # '-'
         [0.0, 0.0, 1.0, 0.0, 0.0],  # t=5 # 'l'
         [0.0, 0.0, 0.0, 1.0, 0.0],  # t=6 # 'o'
         [0.0, 0.0, 0.0, 0.0, 1.0]], # t=7 # '-'
        dtype=np.float32)

    # [batch_size, n_time_steps, alphabet_size]
    inputs = np.array([input_prob_matrix_0])
    print('inputs.shape', inputs.shape) #

    # [batch_size, ]
    input_length = np.array([seq_len_0], dtype=np.int32)
    print('input_length.shape', input_length.shape) #

    decode_pred_np, log_prob_pred_np = ctc_decode_np(inputs, input_length)

    # [batch_size, n_time_steps]
    print('decode_pred_np.shape', decode_pred_np.shape) #
    print(decode_pred_np)

    # [batch_size, 1]
    print('log_prob_pred_np.shape', log_prob_pred_np.shape) #
    print(log_prob_pred_np)

Вывод:

------------------------------------------------------------
inputs.shape (1, 8, 5)
input_length.shape (1,)
n_samples 1
n_classes 5
log_prob.shape (1, 1)
decoded_dense.shape (1, 8)
------------------------------------------------------------
decoded: [0 0 1 2 4 2 3 4] hhel-lo-
prob[np.arange(length), decoded] : [1. 1. 1. 1. 1. 1. 1. 1.]
np.log(prob[np.arange(length), decoded]) : [0. 0. 0. 0. 0. 0. 0. 0.]
log_prob[i] [-0.]
decoded: remove_repeats: [0 1 2 4 2 3 4] hel-lo-
decoded: remove_blanks: [0 1 2 2 3] hello
------------------------------------------------------------
decode_pred_np.shape (1, 8)
[[ 0.  1.  2.  2.  3. -1. -1. -1.]]
log_prob_pred_np.shape (1, 1)
[[-0.]]
...