Моя модель берет пару изображений в качестве входных данных и должна вывести карту изменений (изображение). Для этого я объединил генераторы данных для 2 входных изображений, но я не могу дать соответствующий выходной генератор для модели. Ниже приведен код, который, как я думал, будет работать, возвращая все X1, X2 и X3. У меня есть 14 тренировочных изображений и размер изображения (256, 256,3), а выходная форма - (256, 256,1)
data_gen_args = dict(
zoom_range=1.8,
horizontal_flip=True,
vertical_flip=True,
rescale=1./255,
data_format='channels_last',
validation_split=0.2)
image_datagen = ImageDataGenerator(**data_gen_args)
seed = 1
''' REMEMBER!
While using .flow_from_directory, the path needs to be to a folder, that contains folders, that contains the images '''
def combine_generators(generator, A_trainx, B_trainx, y_train, batch_size):
gen1 = generator.flow(x=A_trainx,
y=y_train,
seed=seed,
batch_size=batch_size,
shuffle=False)
gen2 = generator.flow(x=B_trainx,
y=y_train,
seed=seed,
batch_size=batch_size,
shuffle=False)
gen3 = generator.flow(x=y_train,
y=None,
seed=seed,
batch_size=batch_size,
shuffle=False)
while True:
X1 = gen1.next()
X2 = gen2.next()
X3 = gen3.next()
yield [X1[0], X2[0]], X3[0]
Но когда я запускаю следующий код, где A_trainx, B_trainx и y_train все массивы,
batch_size = 16
model = unet_modified()
model.fit_generator(inputgenerator, steps_per_epoch=20, epochs=20, shuffle=True)
Я получаю следующую ошибку. ПОЖАЛУЙСТА, ПОМОГИТЕ МНЕ!
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:90: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=[<tf.Tenso..., outputs=Tensor("ou...)`
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-19-76cdb8233193> in <module>()
8 model = unet_modified()
9
---> 10 model.fit_generator(inputgenerator, steps_per_epoch=20, epochs=20, shuffle=True)
11
12 model.save(filepath)
5 frames
/usr/local/lib/python3.6/dist-packages/keras/engine/training_utils.py in check_array_length_consistency(inputs, targets, weights)
242 'the same number of samples as target arrays. '
243 'Found ' + str(list(set_x)[0]) + ' input samples '
--> 244 'and ' + str(list(set_y)[0]) + ' target samples.')
245 if len(set_w) > 1:
246 raise ValueError('All sample_weight arrays should have '
ValueError: Input arrays should have the same number of samples as target arrays. Found 14 input samples and 256 target samples.