Модифицированная модель Keras VGG16, дающая один и тот же прогноз каждый раз - PullRequest
0 голосов
/ 27 сентября 2019

Я взял встроенную модель keras.applications.vgg16.VGG16 (weights = 'imagenet', include_top = True, input_shape = (224,224,3)) и перенес обучение для набора данных PASCAL VOC 2012 с 20 классами, включая GlobalСредний уровень объединения, как показано ниже:

def VGG16_modified():
    base_model = vgg16.VGG16(include_top=True,weights='imagenet',input_shape=(224,224,3))
    print(base_model.summary())
    x = base_model.get_layer('block5_pool').output
    x = (GlobalAveragePooling2D())(x)
    predictions = Dense(20,activation='sigmoid')(x)

    final_model = Model(input = base_model.input, output = predictions)
    print(final_model.get_weights())
    return final_model

Теперь я хочу взять карту активации класса, основанную на этой бумаге.Для этого мой код приведен ниже:

 def get_CAM(model,img):
        model = load_model(model)
        im = image.load_img(img,target_size=(224,224))
        im = image.img_to_array(im)
        im = np.expand_dims(im,axis=0)
        class_weights = model.layers[-1].get_weights()[0]
        final_conv_layer = model.get_layer('block5_pool')
        cam_model = Model(inputs = model.input,outputs=(final_conv_layer.output,model.layers[-1].output))
        conv_outputs, predictions = cam_model.predict(im)
        conv_outputs = np.squeeze(conv_outputs)
        prediction = np.argmax(predictions)
        print(predictions)
        print(prediction)
        print(conv_outputs)
        print(conv_outputs.shape)
        class_weights = class_weights[:,prediction]
        mat_for_mult = scipy.ndimage.zoom(conv_outputs,(32,32,1),order=1)
        final_output = np.dot(mat_for_mult.reshape((224*224, 512)),class_weights).reshape((224,224))
        print(final_output)
        return final_output

Но cam_model.predict (im) всегда дает один и тот же класс для всех изображений.Я не уверен, где я не так с этим.Поскольку pascal voc 2012 содержит изображения с несколькими метками, я использовал 'sigmoid' в последнем слое updated_vgg16, а не softmax.Можете ли вы дать мне знать, где я ошибся.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...