ошибка keras при использовании кастомной потери - PullRequest
0 голосов
/ 21 мая 2018

Я должен был использовать простую модель BiLSTM с моей собственной функцией потерь в Keras.См. Ниже.

model = Sequential()
model.add(Bidirectional(LSTM(128, return_sequences=True), input_shape=(1,8)))
model.add(Bidirectional(LSTM(128)))
model.add(Dense(64, activation='relu'))
model.add(Dense(20, activation='softmax'))

def my_loss_np(y_true, y_pred):

    labels = [np.argmax(y_pred[i]) for i in range(y_pred.shape[1])]

    loss = np.mean(labels)
    return loss

import keras.backend as K
def my_loss(y_true, y_pred):
    loss = K.eval(my_loss_np(K.eval(y_true), K.eval(y_pred)))
    return loss

Когда я компилирую эту модель, я получаю сообщение об ошибке -

model.compile(loss=my_loss, optimizer='adam')

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'dense_95_target' with dtype float and shape [?,?]
     [[Node: dense_95_target = Placeholder[dtype=DT_FLOAT, shape=[?,?], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

1 Ответ

0 голосов
/ 21 мая 2018

Здесь есть несколько проблем с вашей функцией потерь:

  1. Вы используете NumPy на тензорах, к сожалению, хотя это интуитивно понятно, это не работает.Вам нужно использовать тензорные операторы из Keras backend , они очень похожи.
  2. Для этого вы звоните K.eval, но на этом этапе вы все еще создаете символ расчетный граф, который будет запущен в TensorFlow или Theano.Таким образом, у тензоров нет значения для вычисления, скажем, вам нужно сохранить его символическим, вы можете получить любые значения, как вы делаете в NumPy.
  3. Даже если вы решите описанные выше проблемы, вы используетенедифференцируемая операция argmax, которая не будет работать с алгоритмами градиентного спуска.

Ваша модель выглядит как задача классификации с несколькими метками, 20 классов, поскольку ваш последний уровень равен 20 с softmax.В этом случае литература использует потери categorical-crossentropy для обучения сети классификатора.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...