Ошибка значения функции потерь Keras: ValueError: операция имеет значение None для градиента. в сети LSTM - PullRequest
1 голос
/ 30 мая 2020

Итак, я пытаюсь обучить свою языковую модель сети LSTM и использовать функцию недоумения в качестве функции потерь, но получаю следующую ошибку:

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

Моя функция потерь выглядит следующим образом:

from keras import backend as K
def perplexity_raw(y_true, y_pred):
    """
    The perplexity metric. Why isn't this part of Keras yet?!
    https://stackoverflow.com/questions/41881308/how-to-calculate-perplexity-of-rnn-in-tensorflow
    https://github.com/keras-team/keras/issues/8267
    """
#     cross_entropy = K.sparse_categorical_crossentropy(y_true, y_pred)
    cross_entropy = K.cast(K.equal(K.max(y_true, axis=-1),
                          K.cast(K.argmax(y_pred, axis=-1), K.floatx())),
                  K.floatx())
    perplexity = K.exp(cross_entropy)
    return perplexity

и я создаю свою модель следующим образом:

# define model
model = Sequential()
model.add(Embedding(vocab_size, 500, input_length=max_length-1))
model.add(LSTM(750))
model.add(Dense(vocab_size, activation='softmax'))
print(model.summary())
# compile network
model.compile(loss=perplexity_raw, optimizer='adam', metrics=['accuracy'])
# fit network
model.fit(X, y, epochs=150, verbose=2)

Ошибка возникает, когда я пытаюсь подогнать свою модель. Кто-нибудь знает, что вызывает ошибку и как ее исправить?

1 Ответ

1 голос
/ 30 мая 2020

Это виновники: K.argmax и K.max. У них нет градиента. Я также думаю, что они вам не нужны в метрике потерь c! Это потому, что max ing и argmax ing что-то удаляет информацию о том, насколько прогноз неверен.

Я не знаю, какие потери вы хотите измерить, но я думаю, что вы ищете что-то вроде tf.exp(tf.nn.sigmoid_cross_entropy_with_logits(y_true, y_pred)) или tf.exp(tf.softmax_cross_entopy_with_logits(y_true, y_pred)). Возможно, вам придется преобразовать ваши логиты в одну горячую кодировку, используя tf.one_hot.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...