Я хотел бы использовать кнопку i python для запуска функции, которая обучает модель глубокого обучения с использованием Keras fit.generator () и ImageDataGenerator (). Я пытался использовать lambda для передачи аргументов функции, но он возвращает TypeError: expected str, bytes or os.PathLike object, not Button.
Код:
def trainGenerator(batch_size,train_path,image_folder,mask_folder,aug_dict,image_color_mode = "grayscale",
mask_color_mode = "grayscale",image_save_prefix = "image",mask_save_prefix = "mask",
flag_multi_class = False,num_class = 2,save_to_dir = None,target_size = (256,256),seed = 1):
image_datagen = ImageDataGenerator(**aug_dict)
mask_datagen = ImageDataGenerator(**aug_dict)
image_generator = image_datagen.flow_from_directory(
train_path,
classes = [image_folder],
class_mode = None,
color_mode = image_color_mode,
target_size = target_size,
batch_size = batch_size,
save_to_dir = save_to_dir,
save_prefix = image_save_prefix,
seed = seed)
mask_generator = mask_datagen.flow_from_directory(
train_path,
classes = [mask_folder],
class_mode = None,
color_mode = mask_color_mode,
target_size = target_size,
batch_size = batch_size,
save_to_dir = save_to_dir,
save_prefix = mask_save_prefix,
seed = seed)
train_generator = zip(image_generator, mask_generator)
for (img,mask) in train_generator:
img,mask = adjustData(img,mask,flag_multi_class,num_class)
yield (img,mask)
def segmentation_training(trainfolder, modelname):
data_gen_args = dict(rotation_range=0.1,
width_shift_range=[0.0, 0, 0.5],
height_shift_range=[0.0, 0, 0.5],
zoom_range=[0.5,1],
horizontal_flip=True,
fill_mode='nearest')
myGene = trainGenerator(2,trainfolder,'image','label',data_gen_args,save_to_dir = None)
model = unet()
model_checkpoint = ModelCheckpoint(os.path.join('Models',modelname+'.hdf5'), monitor='loss',verbose=1, save_best_only=True)
model.fit_generator(myGene,steps_per_epoch=3,epochs=1,callbacks=[model_checkpoint])
modelname = "test"
trainfolder = Path('Data/Segmentation/dataset/train')
btn = widgets.Button(description="Run")
btn.on_click(lambda trainfolder=trainfolder, modelname=modelname : segmentation_training(trainfolder,modelname))
display(btn)
Ошибка:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-41-d4282548b872> in <lambda>(trainfolder, modelname)
46 trainfolder = Path('Data/Segmentation/dataset/train')
47 btn = widgets.Button(description="Run")
---> 48 btn.on_click(lambda trainfolder=trainfolder, modelname=modelname : segmentation_training(trainfolder,modelname))
49 display(btn)
<ipython-input-41-d4282548b872> in segmentation_training(trainfolder, modelname)
40 model = unet()
41 model_checkpoint = ModelCheckpoint(os.path.join('Models',modelname+'.hdf5'), monitor='loss',verbose=1, save_best_only=True)
---> 42 model.fit_generator(myGene,steps_per_epoch=3,epochs=1,callbacks=[model_checkpoint])
43
44
~/virtualenv/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name +
90 '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
~/virtualenv/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1413 use_multiprocessing=use_multiprocessing,
1414 shuffle=shuffle,
-> 1415 initial_epoch=initial_epoch)
1416
1417 @interfaces.legacy_generator_methods_support
~/virtualenv/lib/python3.6/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
175 batch_index = 0
176 while steps_done < steps_per_epoch:
--> 177 generator_output = next(output_generator)
178
179 if not hasattr(generator_output, '__len__'):
~/virtualenv/lib/python3.6/site-packages/keras/utils/data_utils.py in get(self)
791 success, value = self.queue.get()
792 if not success:
--> 793 six.reraise(value.__class__, value, value.__traceback__)
~/virtualenv/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
691 if value.__traceback__ is not tb:
692 raise value.with_traceback(tb)
--> 693 raise value
694 finally:
695 value = None
~/virtualenv/lib/python3.6/site-packages/keras/utils/data_utils.py in _data_generator_task(self)
656 # => Serialize calls to
657 # infinite iterator/generator's next() function
--> 658 generator_output = next(self._generator)
659 self.queue.put((True, generator_output))
660 else:
<ipython-input-41-d4282548b872> in trainGenerator(batch_size, train_path, image_folder, mask_folder, aug_dict, image_color_mode, mask_color_mode, image_save_prefix, mask_save_prefix, flag_multi_class, num_class, save_to_dir, target_size, seed)
13 save_to_dir = save_to_dir,
14 save_prefix = image_save_prefix,
---> 15 seed = seed)
16 mask_generator = mask_datagen.flow_from_directory(
17 train_path,
~/virtualenv/lib/python3.6/site-packages/keras_preprocessing/image.py in flow_from_directory(self, directory, target_size, color_mode, classes, class_mode, batch_size, shuffle, seed, save_to_dir, save_prefix, save_format, follow_links, subset, interpolation)
962 follow_links=follow_links,
963 subset=subset,
--> 964 interpolation=interpolation)
965
966 def standardize(self, x):
~/virtualenv/lib/python3.6/site-packages/keras_preprocessing/image.py in __init__(self, directory, image_data_generator, target_size, color_mode, classes, class_mode, batch_size, shuffle, seed, data_format, save_to_dir, save_prefix, save_format, follow_links, subset, interpolation)
1731 self.samples = sum(pool.map(function_partial,
1732 (os.path.join(directory, subdir)
-> 1733 for subdir in classes)))
1734
1735 print('Found %d images belonging to %d classes.' %
/usr/lib/python3.6/multiprocessing/pool.py in map(self, func, iterable, chunksize)
264 in a list that is returned.
265 '''
--> 266 return self._map_async(func, iterable, mapstar, chunksize).get()
267
268 def starmap(self, func, iterable, chunksize=None):
/usr/lib/python3.6/multiprocessing/pool.py in _map_async(self, func, iterable, mapper, chunksize, callback, error_callback)
374 raise ValueError("Pool not running")
375 if not hasattr(iterable, '__len__'):
--> 376 iterable = list(iterable)
377
378 if chunksize is None:
~/virtualenv/lib/python3.6/site-packages/keras_preprocessing/image.py in <genexpr>(.0)
1731 self.samples = sum(pool.map(function_partial,
1732 (os.path.join(directory, subdir)
-> 1733 for subdir in classes)))
1734
1735 print('Found %d images belonging to %d classes.' %
/usr/lib/python3.6/posixpath.py in join(a, *p)
78 will be discarded. An empty last part will result in a path that
79 ends with a separator."""
---> 80 a = os.fspath(a)
81 sep = _get_sep(a)
82 path = a
TypeError: expected str, bytes or os.PathLike object, not Button
Когда я запускаю segmentation_train(trainpath,modelname)
без реализации кнопки, все работает нормально. Как я могу вызвать функцию, нажав кнопку? Заранее спасибо