Я изучаю классификацию изображений в керасе. Я скачал образец набора пончиков и вафель, но они различаются по размеру. Чтобы стандартизировать их размер, я загружаю изображения из их каталогов, изменяю их размер и сохраняю их в виде массивов:
test_data_dir = 'v_data/train/donuts_and_waffles/'
validation_data_dir = 'v_data/test/donuts_and_waffles/'
loaded_test_donuts = list()
for filename in listdir(test_data_dir + 'donuts/'):
image1 = Image.open(test_data_dir + 'donuts/' + filename)
img_resized = image1.resize((224,224))
img_data = asarray(img_resized)
loaded_test_donuts.append(img_data)
loaded_test_waffles = list()
for filename in listdir(test_data_dir + 'waffles/'):
image1 = Image.open(test_data_dir + 'waffles/' + filename)
img_resized = image1.resize((224,224))
img_data = asarray(img_resized)
loaded_test_waffles.append(img_data)
loaded_validation_donuts = list()
for filename in listdir(validation_data_dir + 'donuts/'):
image1 = Image.open(validation_data_dir + 'donuts/' + filename)
img_resized = image1.resize((224,224))
img_data = asarray(img_resized)
loaded_validation_donuts.append(img_data)
loaded_validation_waffles = list()
for filename in listdir(validation_data_dir + 'waffles/'):
image1 = Image.open(validation_data_dir + 'waffles/' + filename)
img_resized = image1.resize((224,224))
img_data = asarray(img_resized)
loaded_validation_waffles.append(img_data)
test_data = list()
validation_data = list()
test_data.append(np.array(loaded_test_donuts))
test_data.append(np.array(loaded_test_waffles))
validation_data.append(np.array(loaded_validation_donuts))
validation_data.append(np.array(loaded_validation_waffles))
test_data = np.array(test_data)
validation_data = np.array(validation_data)
Затем я хочу создать ImageDataGenerator для моих данных:
train_datagen = ImageDataGenerator(
rescale=1. / 255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1. / 255)
train_generator = train_datagen.flow(
#how can I pass here test_data to make it work (along with which parameters)
)
validation_generator = test_datagen.flow(
#how can I pass here validation_data to make it work (along with which parameters)
)
Как этого добиться? Я пытался так:
train_generator = train_datagen.flow(
test_data, #does not work
batch_size=batch_size)
validation_generator = test_datagen.flow(
validation_data, #does not work
batch_size=batch_size)
но потом я получаю эту ошибку:
Traceback (most recent call last):
...
ValueError: ('Input data in `NumpyArrayIterator` should have rank 4. You passed an array with shape', (2, 770, 224, 224, 3))