Con vnet Порядковая функция регрессии потери - PullRequest
1 голос
/ 09 февраля 2020

В настоящее время я пытаюсь перенести обучение с Денсом enet, обученным на Имаге enet, чтобы вывести порядковое целое число {2 <3 <4 <5 <6}. Я закодировал целевые переменные в двоичные векторы длины 4 (то есть [1,0,0,0], [1,1,0,0] и т. Д. c.), Используя <a href="https://stackoverflow.com/questions/38375401/neural-network-ordinal-classification-for-age"> этот метод . Ниже приведена архитектура моей модели:

base_model = DenseNet121(include_top=False, weights="imagenet", classes=5, input_shape=(224,224,3))
base_model.trainable = False
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.8)(x)
preds = Dense(4, activation="sigmoid", 
axis=1),
             )(x)
model = Model(inputs=base_model.input, outputs=preds, name = "ordered_logit")

model.compile(optimizer=keras.optimizers.Nadam(), loss='binary_crossentropy', metrics=[soft_acc_multi_output])

Где 'soft_acc_multi_output' - это мой пользовательский показатель c, который выдает 1, если все записи соответствуют истинным значениям, и 0 в противном случае.

import tensorflow.keras.backend as K
def soft_acc_multi_output(y_true, y_pred):
    return K.mean(K.all(K.equal(K.cast(K.round(y_true),'int32'), K.cast(K.round(y_pred),'int32')),axis=1))

В настоящее время я использую ' binary_crossentropy ', но я понял, что это не говорит модели, что, если истинная метка равна [1,1,0,0], затем [0,9, 0,7, 0, 0,6] должны быть оштрафованы более строго, чем [0,9, 0,7, 0,6, 0], но в настоящее время штрафы за них идентичны. Как мне изменить функцию потерь, чтобы она могла распознавать эту разницу?

1 Ответ

0 голосов
/ 07 апреля 2020

Когда нам нужно использовать функцию потерь (или метри c), отличную от доступных, нам нужно создать нашу собственную пользовательскую функцию потерь и перейти к model.compile.

Метри c функция аналогична функции потерь, за исключением того, что результаты оценки метри c не используются при обучении модели. Мы можем использовать любую из функций потерь как функцию metri c.

Итак, создайте пользовательскую функцию потерь, которая использует построенный вами расчет для оценки метри c, чтобы определить дополнительный штраф или убыток, который будет оптимизирован моделью.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...