Что именно tenensflow.gather () делает? - PullRequest
0 голосов
/ 28 июня 2019

Я видел код потери триплета, который содержит функцию tf.gather (). Что делает эта функция?

Я прошёл официальный сайт tenorflow для определения, но всё ещё не могу его получить.

def margin_triplet_loss(y_true, y_pred, margin, batch_size):
    anchor = tf.gather(y_pred, tf.range(0, batch_size, 3))
    positive = tf.gather(y_pred, tf.range(1, batch_size, 3))
    negative = tf.gather(y_pred, tf.range(2, batch_size, 3))

    loss = K.maximum(margin
                 + K.sum(K.square(anchor-positive), axis=1)
                 - K.sum(K.square(anchor-negative), axis=1),
                 0.0)
    return K.mean(loss)

1 Ответ

1 голос
/ 28 июня 2019

tf.gather - это функция для индексации массива. Вы собираете элементы, которые вы указываете в качестве аргумента index. Это не является естественным для тензорных тензорных течений.

tf.gather (y_pred, tf.range (0, batch_size, 3)) эквивалентно по numpy y_pred [0: batch_size: 3], что означает, что вы возвращаете каждый третий элемент, начиная с первого.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...