В приведенном ниже коде я получаю следующую ошибку: TypeError: fit_generator() missing 1 required positional argument: 'generator'
.Это будет первый аргумент в коде, верно?Может кто-нибудь объяснить, почему я все еще получаю эту ошибку и как ее решить?
x_train имеет форму (400, 256, 256, 4) dtype = float64.y_train имеет форму (400, 256, 256) dtype = uint8.
x_val имеет форму (100, 256, 256, 4) dtype = float64.y_val - это форма (100, 256, 256) dtype = uint8.
# Create image generator
data_gen_args = dict(rotation_range=5,
width_shift_range=0.1,
height_shift_range=0.1,
validation_split=0.2)
image_datagen = ImageDataGenerator(**data_gen_args)
seed = 1
batch_size = 4
def XYaugmentGenerator(X1, y, seed, batch_size):
genX1 = gen.flow(X1, y, batch_size=batch_size, seed=seed)
genX2 = gen.flow(y, X1, batch_size=batch_size, seed=seed)
while True:
X1i = genX1.next()
X2i = genX2.next()
yield X1i[0], X2i[0]
# Train model
Model.fit_generator(XYaugmentGenerator(x_train, y_train, seed, batch_size), steps_per_epoch=np.ceil(float(len(images)) / float(batch_size)),
validation_data = XYaugmentGenerator(x_val, y_val,seed, batch_size),
validation_steps = np.ceil(float(len(x_val)) / float(batch_size))
, shuffle=True, epochs=20)
Полная ошибка при трассировке приведена ниже:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-68-54be54c1b7c7> in <module>()
30 train_generator,
31 steps_per_epoch=20,
---> 32 epochs=1)
~/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name +
90 '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
TypeError: fit_generator() missing 1 required positional argument: 'generator'