Tensorflow Keras разного количества категорий в обучении и валидации - PullRequest
1 голос
/ 28 апреля 2020

Я строю классификатор изображений, используя встроенный в Keras Re snet (добавленный в последовательный) и генераторы данных. Изображения хранятся в отдельных папках, причем папки выступают в качестве классов.

Проблема состоит в том, что в наборе обучающих данных 464 класса по сравнению с 683 в наборе проверочных данных. Таким образом, когда я запускаю:

model.fit_generator(
    train_datagen, 
    steps_per_epoch = STEP_SIZE_TRAIN,
    epochs = EPOCHS,
    verbose = 1,
    callbacks = [cp_callback, cp_tensorboard],
    validation_data = val_datagen,
    validation_freq = 2
)

, я получаю ошибку

ValueError: Error when checking target: expected dense_2 to have shape (464,) but got array with shape (683,)

Это имеет смысл; модель, обученная на тренировочном наборе, не может оценить классы, для которых у нее нет узлов. Тем не менее, возможно ли изменить мою модель или разделение набора данных для решения этой проблемы?

В качестве альтернативы, есть ли способ использовать validation_split с генераторами данных, позволяющий мне проверять, не касаясь отдельного набора данных?

1 Ответ

0 голосов
/ 04 мая 2020

Вам необходимо поддерживать DataFrame для набора проверки, который состоит только из обученных классов. Позже вы можете использовать функцию Keras ImageDataGenerator flow_from_dataframe , чтобы решить вашу проблему. Вы можете следовать приведенному ниже коду.

images = []
classes = []
#To filter classes from validation set 
for i in train_classes:
  image_list = os.listdir("Validation/" + i + "/") 
  cl = [i] * len (image_list)
  images.extend(image_list)
  classes.extend(cl)

val_df = pd.DataFrame({"Images":images, "Classes":classes})

val_datgen = ImageDataGenerator(rescale=1./255)

validation_generator = val_datagen.flow_from_dataframe(
        dataframe=val_df,
        directory='Validation',
        x_col="Images",
        y_col="Classes",
        target_size=(150, 150),
        batch_size=32,
        class_mode='categorical')

Где train_classes - это список классов в ваших тренировочных данных.

...