слой свертки как выходной слой для задачи классификации - PullRequest
0 голосов
/ 03 мая 2020

Я пытаюсь построить модель глубокого обучения для классификации набора данных cifar10 из 10 классов. Теперь я хочу, чтобы в качестве выходного слоя использовался слой свертки, и этот слой (filters = 10) должен принять входные данные из сглаживания и предсказать мой класс.

код моей модели

num_class = 10

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))


model.add(Flatten())
model.add(Dense(num_classes))
model.add(Conv2D(10, (3,3)))
model.add(Activation('softmax'))

, но это дает мне ошибку

Input 0 of layer conv2d_34 is incompatible with the layer: expected ndim=4, found ndim=2. Full shape received: [None, 6272]

как мне этого добиться?

1 Ответ

2 голосов
/ 03 мая 2020

Вы используете слой Flatten перед сверточным слоем. Flatten делает тензорный вывод 2-ым, но Conv2D требует 4-ые данные. Просто прокомментируйте линию слоя Flatten и все будет работать нормально.

В вашей модели нет модуля классификации, вам нужно иметь слой Dense с количеством классов в последнем слое.

#model.add(Flatten()) # comment this line
model.add(Dropout(0.5))
model.add(Conv2D(10,(3,3)))
model.add(Flatten())
model.add(Dense(num_class)) # num_class is how many classes do you have in your dataset
model.add(Activation('softmax'))

Вы можете использовать слой свертки в качестве окончательного результата при некотором глобальном пуле. Например, следующая модель использует GlobalAveragePooling.

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))


model.add(Conv2D(10, (3,3)))
model.add(GlobalAveragePooling2D())
model.add(Activation('softmax'))
model.summary()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...