tf.gather выходит за границы при использовании пользовательской функции softmax_loss, даже если она не должна - PullRequest
3 голосов
/ 13 марта 2019

Я использую небольшую пользовательскую функцию внутри tf.contrib.seq2seq.sequence_loss(softmax_loss_function=[...]) в качестве пользовательской функции sofmax_loss_function:

    def reduced_softmax_loss(self, labels, logits):
        top_logits, indices = tf.nn.top_k(logits, self.nb_top_classes, sorted=False)
        top_labels = tf.gather(labels, indices)

        return tf.nn.softmax_cross_entropy_with_logits_v2(labels=top_labels,
                                                          logits=top_logits)

Но даже если метки и логиты должны иметь одинаковое измерение, после выполнения возвращается и InvalidArgumentError:

indices[1500,1] = 2158 is not in [0, 1600) с числами, меняющимися из-за моего случайного семени.

Есть ли другая функция, такая как tf.gather, которую я мог бы использовать вместо этого? Или возвращаемое значение в ложной форме?

Все работает нормально, если я передаю обычные функции Tensorflow.

Заранее спасибо!

1 Ответ

0 голосов
/ 14 марта 2019

Трудно сказать, что происходит, просто посмотрев на ваш код, но я не думаю, что код, который вы написали, делает то, что вы хотите, чтобы он делал. Операция tf.gather ожидает ввода индексов, где каждое скалярное значение индексируется во внешнем измерении первого аргумента, но здесь вывод top_k пытается индексировать как строки, так и столбцы, что приводит к ошибкам вне границ.

...