Я пытаюсь классифицировать некоторые CXR-изображения, которые имеют несколько меток на образец. Из того, что я понимаю, я должен нанести плотный слой с сигмовидными активациями и использовать бинарную кроссентропию в качестве моей функции потерь. Проблема в том, что существует большой дисбаланс классов (гораздо больше норм, чем ненормальных). Мне любопытно, вот мой модельный софар:
from keras_applications.resnet_v2 import ResNet50V2
from keras.layers import GlobalAveragePooling2D, Dense
from keras import Sequential
ResNet = Sequential()
ResNet.add(ResNet50V2(input_shape=shape, include_top=False, weights=None,backend=keras.backend,
layers=keras.layers,
models=keras.models,
utils=keras.utils))
ResNet.add(GlobalAveragePooling2D(name='avg_pool'))
ResNet.add(Dense(len(label_counts), activation='sigmoid', name='Final_output'))
Как мы можем видеть, я использую сигмоид, чтобы получить вывод, но я немного запутался относительно того, как реализовать веса. Я думаю, что мне нужно использовать пользовательскую функцию потерь, которая использует BCE (use_logits = true). Примерно так:
xent = tf.losses.BinaryCrossEntropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE)
loss = tf.reduce_mean(xent(targets, pred) * weights))
Таким образом, он обрабатывает выходные данные как логиты, но в чем я не уверен, так это в активации окончательного вывода. Сохраняю ли я его при активации сигмовидной кишки, или я использую линейную активацию (не активирована)? Я предполагаю, что мы сохраняем сигмовидную кишку, и просто относимся к ней как к git, но я не уверен, так как пиктограммы "torch.nn.BCEWithLogitsLoss
" содержат сигмовидный слой
РЕДАКТИРОВАТЬ: Найдено это: https://www.reddit.com/r/tensorflow/comments/dflsgv/binary_cross_entropy_with_from_logits_true/
Согласно: pgaleone
from_logits = True означает, что функция потерь ожидает линейный тензор (выходной уровень вашей сети без какой-либо функции активации, кроме идентификатора), поэтому вы должны удалить сигмовидную оболочку, поскольку сама функция потерь будет применять softmax к выходу вашей сети, а затем вычислять кросс-энтропию