Странная ошибка с обучением пользовательских изображений с Xception CNN - PullRequest
1 голос
/ 11 октября 2019

У меня есть серия небольших изображений, и я хотел бы провести тренинг по XCeption CNN, используя их. Набор для обучения и проверки имеет, соответственно, следующую форму:

>(63787, 72, 72, 3) 

>(155, 72, 72, 3)

Другими словами, мои изображения соответствуют требованиям Xception минимальной формы 71,71,3 для ввода Xceptions.

Так я строю модель

    base_model=xception.Xception(include_top=False, weights=None, input_shape=(72, 72, 3))
    x = base_model.output
    x = Dense(256, activation='relu')(x)
    predictions = Dense(classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=predictions)
    opt= SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"])

Так я тренирую модель

checkpoint = ModelCheckpoint(filename, monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False, mode='max', period=self.__model_history_period)
lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1), cooldown=0, patience=1000, min_lr=0.5e-6)
early_stopper = EarlyStopping(min_delta=0.0001, patience=10000)
callbacks_list = [checkpoint, lr_reducer, early_stopper]

datagen = ImageDataGenerator(width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True, vertical_flip=True) 
datagen.fit(x_train)
 # Fit the model on the batches generated by datagen.flow().
H=model.fit_generator(datagen.flow(X_train, y_train,batch_size=self.__bs), steps_per_epoch=X_train.shape[0] // self.__bs, validation_data=(x_val, y_val), epochs=self.__epochs, verbose=1, max_q_size=100, callbacks=[checkpoint, lr_reducer, early_stopper])   

Однако, когда я запускаю тренинг CNN, у меня возникает следующая ошибка:

> Traceback (most recent call last):
  File "esperimento_paper.py", line 86, in <module>
    vgg.run_2D()
  File "Desktop/PhD-Market-Nets/src/classes/VggHandler.py", line 662, in run_2D
    model, H, n_epochs = self.__train_2D(x_train=x_train, y_train=y_train, x_val=x_val, y_val=y_val, index_net=index_net, index_walk=index_walk)
  File "Desktop/PhD-Market-Nets/src/classes/VggHandler.py", line 267, in __train_2D
    callbacks=[checkpoint, lr_reducer, early_stopper])  
  File "Desktop/PhD-Market-Nets/venv/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "Desktop/PhD-Market-Nets/venv/lib/python3.6/site-packages/keras/engine/training.py", line 1418, in fit_generator
    initial_epoch=initial_epoch)
  File "Desktop/PhD-Market-Nets/venv/lib/python3.6/site-packages/keras/engine/training_generator.py", line 144, in fit_generator
    val_x, val_y, val_sample_weight)
  File "Desktop/PhD-Market-Nets/venv/lib/python3.6/site-packages/keras/engine/training.py", line 789, in _standardize_user_data
    exception_prefix='target')
  File "Desktop/PhD-Market-Nets/venv/lib/python3.6/site-packages/keras/engine/training_utils.py", line 128, in standardize_input_data
    'with shape ' + str(data_shape))
ValueError: Error when checking target: expected dense_2 to have 4 dimensions, but got array with shape (155, 1)

Что мне здесь не хватает?

1 Ответ

3 голосов
/ 11 октября 2019

Проблема заключается в этой строке:

x = base_model.output
x = Dense(256, activation='relu')(x)

Когда вы установите include_top=False - то, что возвращается, это карта объектов формы [number of examples, h, w, number of features]. Когда вы применяете Dense к этой карте объектов - вы применяете ее только к последнему измерению (это похоже на свертки 1x1). Вот почему на выходе получается 4D. Чтобы преодолеть это, попробуйте:

x = base_model.output
x = Flatten()(x)
x = Dense(256, activation='relu')(x)

Flatten раздавит измерения h, w и number of features в одно измерение. Благодаря этому ваша сеть должна работать нормально.

PS. Вы также можете попробовать использовать GlobalMaxPooling2D или его среднюю версию. Это позволит пропустить пространственное положение фильтров, но существенно снизит объем памяти модели.

...