Модель keras с tf.contrib.losses.metric_learning.triplet_semihard_loss Ошибка утверждения - PullRequest
0 голосов
/ 01 января 2019

Я использую python 3 с anaconda и пытаюсь использовать функцию потерь tf.contrib с моделью Keras.

Код следующий

from keras.layers import Dense, Flatten
from keras.optimizers import Adam
from keras.models import Sequential
from tensorflow.contrib.losses import metric_learning
model = Sequential()
model.add(Flatten(input_shape=input_shape))
model.add(Dense(50,  activation="relu"))
model.compile(loss=metric_learning.triplet_semihard_loss, optimizer=Adam())

Я получаю следующую ошибку:

Файл "/home/user/.local/lib/python3.6/site-packages/keras/engine/training_utils.py ", строка 404, во взвешенном файле score_array = fn (y_true, y_pred)" /home/user/anaconda3/envs/siamese/lib/python3.6/site-packages/tenorflow / contrib / loss / python / metric_learning / metric_loss_ops.py ", строка 179, в triplet_semihard_loss assert lshape.shape == 1 AssertionError

Когда я использую ту же сеть с функцией потери keras, этоработает нормально, я пытался обернуть функцию потери tf в такую ​​функцию, как

def func(y_true, y_pred): 
    import tensorflow as tf
    return tf.contrib.losses.metric_learning.triplet_semihard_loss(y_true, y_pred) 

И все равно получаю ту же ошибку

Что я здесь не так делаю?

обновление: при изменении функции для возврата следующего

return K.categorical_crossentropy(y_true, y_pred)

все работает отлично!Но я не могу заставить его работать с определенной функцией потери tf ...

Когда я захожу в tf.contrib.losses.metric_learning.triplet_semihard_loss и удаляю эту строку кода: assert lshape.shape == 1 все работает нормально

Спасибо

Ответы [ 2 ]

0 голосов
/ 12 февраля 2019

Проблема в том, что вы передаете неправильный ввод в функцию потерь.

Согласно triplet_semihard_loss docstring вам необходимо передать labels и embeddings.

Итак, ваш код должен быть:

def func(y, embeddings): 
    return tf.contrib.losses.metric_learning.triplet_semihard_loss(labels=y, embeddings=embeddings) 

И еще два замечания о сети для вложений:

  1. Последний плотный слой должен быть безактивация

  2. Не забудьте нормализовать выходной вектор model.add(Lambda(lambda x: K.l2_normalize(x, axis=1)))

0 голосов
/ 01 января 2019

Похоже, что ваша проблема связана с неправильным вводом в функцию потерь.Фактически, триплетная потеря требует параметров:

Args:
labels: 1-D tf.int32 `Tensor` with shape [batch_size] of
  multiclass integer labels.
embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should
  be l2 normalized.

Вы уверены, что y_true имеет правильную форму?Можете ли вы дать нам более подробную информацию о используемых вами тензорах?

...