Прогнозирование только одного класса - PullRequest
0 голосов
/ 11 марта 2020

Я применяю трансферное обучение с использованием VGG16 для классификации бинарной классификации диабетической болезни c ретинопатии. Даже после балансировки классов моя модель предсказывает только один класс. почему это происходит. Ниже мой код

base_model=VGG16(weights='imagenet',include_top=False) #imports the mobilenet model and discards the last 1000 neuron layer.

x=base_model.output
x=GlobalAveragePooling2D()(x)
x=Dense(1024,activation='relu')(x) #we add dense layers so that the model can learn more complex functions and classify for better results.
x=Dense(1024,activation='relu')(x) #dense layer 2
x=Dense(512,activation='relu')(x) #dense layer 3
preds=Dense(1,activation='softmax')(x) #final layer with softmax activation
vgg=Model(inputs=base_model.input,outputs=preds)

1 Ответ

1 голос
/ 11 марта 2020

Похоже, вы используете функцию активации softmax на своем выходе. Softmax обычно используется, когда вы классифицируете входные данные несколькими возможными классами, так как он выводит распределение вероятностей (т. Е. Сумма всех элементов равна 1) Для этого сначала возводят в степень каждый элемент, а затем делим каждый на сумму всех элементов.

Однако, если у вас есть только одна выходная единица, тогда он должен будет всегда выводить 1, поскольку он будет вычислять exp (x_1) / exp (x_1) = 1

Для задачи бинарной классификации, которую вы выполняете, я бы рекомендовал вместо этого использовать функцию активации вывода sigmoid :

base_model=VGG16(weights='imagenet',include_top=False) #imports the mobilenet model and discards the last 1000 neuron layer.

x=base_model.output
x=GlobalAveragePooling2D()(x)
x=Dense(1024,activation='relu')(x) #we add dense layers so that the model can learn more complex functions and classify for better results.
x=Dense(1024,activation='relu')(x) #dense layer 2
x=Dense(512,activation='relu')(x) #dense layer 3
preds=Dense(1,activation='sigmoid')(x) #final layer with softmax activation
vgg=Model(inputs=base_model.input,outputs=preds)

Предполагается, что метки в вашем учебном наборе данных равны 0 и 1.

...