Использование потери Хэмминга в качестве пользовательской функции потерь в модели Кераса возвращает «ValueError: None values ​​not supported». - PullRequest
0 голосов
/ 09 июля 2020

Я хотел бы запустить нейронную сеть Keras с функцией потерь Хэмминга. Ниже показано, как я создаю функцию потерь Хэмминга. Кредиты на этот пост

def custom_hamming_loss(y_true, y_pred):
    return tf.cast(tf.math.count_nonzero(tf.cast(y_true, tf.float32) - tf.cast(y_pred, tf.float32), axis=-1), tf.float32)/tf.cast(y_true, tf.float32).shape[-1]
y_train=array([[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
       [0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

y_test=array([[0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

model_for_pruning.compile(optimizer=optimizer,
                          loss=custom_hamming_loss,
                          metrics=[tf.keras.metrics.CategoricalAccuracy()])

Но когда я пытаюсь вызвать model.fit (), я получаю следующую ошибку:

ValueError: None values not supported.

Что такое Я делаю что-то не так?

[ОБНОВЛЕНИЕ]

Я также нашел функцию из аддонов тензорного потока hamming_loss_fn . Равно ли это пользовательской функции, которую я опубликовал? Что ты думаешь. Вкратце, как видите, у меня много нулей и несколько единиц. Поэтому я подумал, что потеря Хэмминга или потеря фокуса пригодятся в таких ситуациях. Оба поставляются с репозиторием tf addons, хотя я хотел, чтобы работала более индивидуальная реализация.

Думаю, пока я буду придерживаться надстроек tf.

введите описание изображения здесь

[ОБНОВЛЕНИЕ - 10.07.2020]

Кажется, проблема была исправлена ​​с помощью tensorflow 2.2.0, однако я получаю ошибку градиентов, которые не могут быть обновлены. И на основе поиска, который я провел, я обнаружил, что потеря хэмминга не может хорошо работать с обратным распространением из-за операции XOR. Поэтому я полагаю, что потери Хэмминга нельзя использовать в качестве функции потерь или метрич. c функции

.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...