Я немного застрял здесь, поэтому вижу несколько советов. У меня есть triple_loss
функция и модель для генерации вложений. Однако я не уверен, как должны выглядеть результаты модели. Когда я вывожу 3 вложения (anchor
, positive
, negative
), я получаю 3 значения потерь.
Я хотел бы понять, что моя модель должна вернуться сюда? У меня есть функция get batch, которая генерирует триплетные партии с привязкой / положительным / отрицательным триплетом
def triplet_loss(y_true, y_pred, alpha = 0.2):
"""
Implementation of the triplet loss as defined by formula (3)
Arguments:
y_true -- true labels, required when you define a loss in Keras, you don't need it in this function.
y_pred -- python list containing three objects:
anchor -- the encodings for the anchor images, of shape (None, 128)
positive -- the encodings for the positive images, of shape (None, 128)
negative -- the encodings for the negative images, of shape (None, 128)
/* Returns loss - Scalar value */
"""
anchor, positive, negative = y_pred[0], y_pred[1], y_pred[2]
pos_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, positive)), axis=-1)
neg_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, negative)), axis=-1)
basic_loss = tf.add(tf.subtract(pos_dist, neg_dist), alpha)
loss = tf.reduce_sum(tf.maximum(triplet_loss, 0))
return loss
У меня есть следующая функция, которая создает вложение для одного изображения / матрицы.
def build_embedding(self, input_shape, dimensions):
inp = Input(shape=input_shape)
/*further conv2d / dense layers
return Model(inputs=inp, outputs=out)
def build_triplets_model(self, shape, dimensions):
net = self.build_embedding(shape, dimensions)
anchor_input = Input(shape=shape, name='anchor')
positive_input = Input(shape=shape, name='positive')
negative_input = Input(shape=shape, name='negative')
# Get the embedded values
encoded_a = net(anchor_input)
encoded_p = net(positive_input)
encoded_n = net(negative_input)
triplet_net = Model(inputs=[anchor_input, positive_input, negative_input],\
outputs=[encoded_a, encoded_p, encoded_n])
return triplet_net