Я видел код потери триплета, который содержит функцию tf.gather (). Что делает эта функция?
Я прошёл официальный сайт tenorflow для определения, но всё ещё не могу его получить.
def margin_triplet_loss(y_true, y_pred, margin, batch_size):
anchor = tf.gather(y_pred, tf.range(0, batch_size, 3))
positive = tf.gather(y_pred, tf.range(1, batch_size, 3))
negative = tf.gather(y_pred, tf.range(2, batch_size, 3))
loss = K.maximum(margin
+ K.sum(K.square(anchor-positive), axis=1)
- K.sum(K.square(anchor-negative), axis=1),
0.0)
return K.mean(loss)