Я использую небольшую пользовательскую функцию внутри tf.contrib.seq2seq.sequence_loss(softmax_loss_function=[...])
в качестве пользовательской функции sofmax_loss_function:
def reduced_softmax_loss(self, labels, logits):
top_logits, indices = tf.nn.top_k(logits, self.nb_top_classes, sorted=False)
top_labels = tf.gather(labels, indices)
return tf.nn.softmax_cross_entropy_with_logits_v2(labels=top_labels,
logits=top_logits)
Но даже если метки и логиты должны иметь одинаковое измерение, после выполнения возвращается и InvalidArgumentError
:
indices[1500,1] = 2158 is not in [0, 1600)
с числами, меняющимися из-за моего случайного семени.
Есть ли другая функция, такая как tf.gather
, которую я мог бы использовать вместо этого? Или возвращаемое значение в ложной форме?
Все работает нормально, если я передаю обычные функции Tensorflow.
Заранее спасибо!