Я использую python 3 с anaconda и пытаюсь использовать функцию потерь tf.contrib с моделью Keras.
Код следующий
from keras.layers import Dense, Flatten
from keras.optimizers import Adam
from keras.models import Sequential
from tensorflow.contrib.losses import metric_learning
model = Sequential()
model.add(Flatten(input_shape=input_shape))
model.add(Dense(50, activation="relu"))
model.compile(loss=metric_learning.triplet_semihard_loss, optimizer=Adam())
Я получаю следующую ошибку:
Файл "/home/user/.local/lib/python3.6/site-packages/keras/engine/training_utils.py ", строка 404, во взвешенном файле score_array = fn (y_true, y_pred)" /home/user/anaconda3/envs/siamese/lib/python3.6/site-packages/tenorflow / contrib / loss / python / metric_learning / metric_loss_ops.py ", строка 179, в triplet_semihard_loss assert lshape.shape == 1 AssertionError
Когда я использую ту же сеть с функцией потери keras, этоработает нормально, я пытался обернуть функцию потери tf в такую функцию, как
def func(y_true, y_pred):
import tensorflow as tf
return tf.contrib.losses.metric_learning.triplet_semihard_loss(y_true, y_pred)
И все равно получаю ту же ошибку
Что я здесь не так делаю?
обновление: при изменении функции для возврата следующего
return K.categorical_crossentropy(y_true, y_pred)
все работает отлично!Но я не могу заставить его работать с определенной функцией потери tf ...
Когда я захожу в tf.contrib.losses.metric_learning.triplet_semihard_loss и удаляю эту строку кода: assert lshape.shape == 1
все работает нормально
Спасибо