Я пишу код для задачи speech2text на основе потери CT C. Как вы знаете, мы должны определить пустой индекс для тихой части речи. Моя модель так проста:
def get_model(input_dim, output_dim,
rnn_units=30) -> Model:
with tf.device('/cpu:0'):
input_tensor = layers.Input([None, input_dim])
x = layers.Lambda(k.expand_dims,
arguments=dict(axis=-1))(input_tensor)
x = layers.Conv2D(filters=32,
kernel_size=[11, 4],
strides=[2, 2],
padding='same',
use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.Conv2D(filters=32,
kernel_size=[11, 2],
strides=[1, 2],
padding='same',
use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.Reshape([-1, input_dim // 4 * 32])(x)
recurrent = layers.LSTM(units=rnn_units,
activation='tanh',
recurrent_activation='sigmoid',
use_bias=True,
return_sequences=True,
)
x = layers.Bidirectional(recurrent,
merge_mode='concat')(x)
x = layers.TimeDistributed(layers.Dense(units=rnn_units * 2))(x)
x = layers.ReLU()(x)
output_tensor = layers.Dense(units=output_dim)(x)
# output_tensor = layers.Lambda(lambda y: softmax(y, axis=-1))(output_tensor)
model = Model(input_tensor, output_tensor)
return model
Затем я использую CT C функцию потерь следующим образом
def get_loss() -> Callable:
def get_length(tensor):
lengths = tf.math.reduce_sum(tf.ones_like(tensor), 1)
return tf.cast(lengths, tf.int32)
def ctc_loss(labels, logits):
label_length = get_length(labels)
logit_length = get_length(tf.math.reduce_max(logits, 2))
return tf.reduce_mean(tf.nn.ctc_loss(labels, logits, label_length, logit_length,
logits_time_major=False, blank_index=-1))
return ctc_loss
И затем оптимизирую функцию потерь следующим образом:
y = Input(name='y', shape=[None], dtype='int32')
model.compile(RMSprop(1e-4), loss=get_loss(), target_tensors=[y])
model.fit(dataset, validation_data=dev_dataset, callbacks=[checkpointer], **kwargs)
Через несколько эпох (с любой скоростью обучения) я получил следующие результаты:
[33 33 33 33 33 33 33 33 33 33 33 33 33 33 33]
Число 33 эквивалентно пустому индексу. В любое время, когда я обучаю сеть, вывод для любого входа таков! В чем проблема? : - \