Используйте кнопку I Python Widget для вызова функции обучения Keras. - PullRequest
0 голосов
/ 03 февраля 2020

Я хотел бы использовать кнопку i python для запуска функции, которая обучает модель глубокого обучения с использованием Keras fit.generator () и ImageDataGenerator (). Я пытался использовать lambda для передачи аргументов функции, но он возвращает TypeError: expected str, bytes or os.PathLike object, not Button.

Код:

def trainGenerator(batch_size,train_path,image_folder,mask_folder,aug_dict,image_color_mode = "grayscale",
                    mask_color_mode = "grayscale",image_save_prefix  = "image",mask_save_prefix  = "mask",
                    flag_multi_class = False,num_class = 2,save_to_dir = None,target_size = (256,256),seed = 1):
    image_datagen = ImageDataGenerator(**aug_dict)
    mask_datagen = ImageDataGenerator(**aug_dict)
    image_generator = image_datagen.flow_from_directory(
        train_path,
        classes = [image_folder],
        class_mode = None,
        color_mode = image_color_mode,
        target_size = target_size,
        batch_size = batch_size,
        save_to_dir = save_to_dir,
        save_prefix  = image_save_prefix,
        seed = seed)
    mask_generator = mask_datagen.flow_from_directory(
        train_path,
        classes = [mask_folder],
        class_mode = None,
        color_mode = mask_color_mode,
        target_size = target_size,
        batch_size = batch_size,
        save_to_dir = save_to_dir,
        save_prefix  = mask_save_prefix,
        seed = seed)
    train_generator = zip(image_generator, mask_generator)
    for (img,mask) in train_generator:
        img,mask = adjustData(img,mask,flag_multi_class,num_class)
        yield (img,mask)

def segmentation_training(trainfolder, modelname):
    data_gen_args = dict(rotation_range=0.1,
                        width_shift_range=[0.0, 0, 0.5],
                        height_shift_range=[0.0, 0, 0.5],
                        zoom_range=[0.5,1],
                        horizontal_flip=True,
                        fill_mode='nearest')   
    myGene = trainGenerator(2,trainfolder,'image','label',data_gen_args,save_to_dir = None)
    model = unet()
    model_checkpoint = ModelCheckpoint(os.path.join('Models',modelname+'.hdf5'), monitor='loss',verbose=1, save_best_only=True)
    model.fit_generator(myGene,steps_per_epoch=3,epochs=1,callbacks=[model_checkpoint])

modelname = "test"
trainfolder = Path('Data/Segmentation/dataset/train')
btn = widgets.Button(description="Run")
btn.on_click(lambda trainfolder=trainfolder, modelname=modelname : segmentation_training(trainfolder,modelname))
display(btn)

Ошибка:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-41-d4282548b872> in <lambda>(trainfolder, modelname)
     46 trainfolder = Path('Data/Segmentation/dataset/train')
     47 btn = widgets.Button(description="Run")
---> 48 btn.on_click(lambda trainfolder=trainfolder, modelname=modelname : segmentation_training(trainfolder,modelname))
     49 display(btn)

<ipython-input-41-d4282548b872> in segmentation_training(trainfolder, modelname)
     40     model = unet()
     41     model_checkpoint = ModelCheckpoint(os.path.join('Models',modelname+'.hdf5'), monitor='loss',verbose=1, save_best_only=True)
---> 42     model.fit_generator(myGene,steps_per_epoch=3,epochs=1,callbacks=[model_checkpoint])
     43 
     44 

~/virtualenv/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

~/virtualenv/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1413             use_multiprocessing=use_multiprocessing,
   1414             shuffle=shuffle,
-> 1415             initial_epoch=initial_epoch)
   1416 
   1417     @interfaces.legacy_generator_methods_support

~/virtualenv/lib/python3.6/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    175             batch_index = 0
    176             while steps_done < steps_per_epoch:
--> 177                 generator_output = next(output_generator)
    178 
    179                 if not hasattr(generator_output, '__len__'):

~/virtualenv/lib/python3.6/site-packages/keras/utils/data_utils.py in get(self)
    791             success, value = self.queue.get()
    792             if not success:
--> 793                 six.reraise(value.__class__, value, value.__traceback__)

~/virtualenv/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None

~/virtualenv/lib/python3.6/site-packages/keras/utils/data_utils.py in _data_generator_task(self)
    656                             # => Serialize calls to
    657                             # infinite iterator/generator's next() function
--> 658                             generator_output = next(self._generator)
    659                             self.queue.put((True, generator_output))
    660                         else:

<ipython-input-41-d4282548b872> in trainGenerator(batch_size, train_path, image_folder, mask_folder, aug_dict, image_color_mode, mask_color_mode, image_save_prefix, mask_save_prefix, flag_multi_class, num_class, save_to_dir, target_size, seed)
     13         save_to_dir = save_to_dir,
     14         save_prefix  = image_save_prefix,
---> 15         seed = seed)
     16     mask_generator = mask_datagen.flow_from_directory(
     17         train_path,

~/virtualenv/lib/python3.6/site-packages/keras_preprocessing/image.py in flow_from_directory(self, directory, target_size, color_mode, classes, class_mode, batch_size, shuffle, seed, save_to_dir, save_prefix, save_format, follow_links, subset, interpolation)
    962             follow_links=follow_links,
    963             subset=subset,
--> 964             interpolation=interpolation)
    965 
    966     def standardize(self, x):

~/virtualenv/lib/python3.6/site-packages/keras_preprocessing/image.py in __init__(self, directory, image_data_generator, target_size, color_mode, classes, class_mode, batch_size, shuffle, seed, data_format, save_to_dir, save_prefix, save_format, follow_links, subset, interpolation)
   1731         self.samples = sum(pool.map(function_partial,
   1732                                     (os.path.join(directory, subdir)
-> 1733                                      for subdir in classes)))
   1734 
   1735         print('Found %d images belonging to %d classes.' %

/usr/lib/python3.6/multiprocessing/pool.py in map(self, func, iterable, chunksize)
    264         in a list that is returned.
    265         '''
--> 266         return self._map_async(func, iterable, mapstar, chunksize).get()
    267 
    268     def starmap(self, func, iterable, chunksize=None):

/usr/lib/python3.6/multiprocessing/pool.py in _map_async(self, func, iterable, mapper, chunksize, callback, error_callback)
    374             raise ValueError("Pool not running")
    375         if not hasattr(iterable, '__len__'):
--> 376             iterable = list(iterable)
    377 
    378         if chunksize is None:

~/virtualenv/lib/python3.6/site-packages/keras_preprocessing/image.py in <genexpr>(.0)
   1731         self.samples = sum(pool.map(function_partial,
   1732                                     (os.path.join(directory, subdir)
-> 1733                                      for subdir in classes)))
   1734 
   1735         print('Found %d images belonging to %d classes.' %

/usr/lib/python3.6/posixpath.py in join(a, *p)
     78     will be discarded.  An empty last part will result in a path that
     79     ends with a separator."""
---> 80     a = os.fspath(a)
     81     sep = _get_sep(a)
     82     path = a

TypeError: expected str, bytes or os.PathLike object, not Button

Когда я запускаю segmentation_train(trainpath,modelname) без реализации кнопки, все работает нормально. Как я могу вызвать функцию, нажав кнопку? Заранее спасибо

1 Ответ

0 голосов
/ 03 февраля 2020

Ваш lambda привязан к классу Button, в который он был передан, , который неявно сделал первый параметр самим объектом Button. В результате параметр trainpath был фактически переименованный btn экземпляр Button. Функции, которые пытались использовать trainpath в качестве строки пути к файлу, были сбиты с толку и поэтому выдавали ошибку.

Если вы хотите сохранить лямбду, просто добавьте self в качестве первого параметра, а затем проигнорируйте его:

btn.on_click(lambda self, trainfolder=trainfolder, modelname=modelname : segmentation_training(trainfolder,modelname))

В противном случае, есть другая предлагаемая реализация с использованием functools и вызов функции с явными параметрами:

import functools 

def click_func(trainfolder,modelname):
    segmentation_training(trainfolder,modelname)

btn.on_click(functools.partial(click_func,trainfolder=trainfolder,modelname=modelname))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...