Как создать собственный ImageDataGenerator для многозадачной сети, выполняющей сегментацию? - PullRequest
0 голосов
/ 22 апреля 2020

Я хочу обучить модифицированную архитектуру U- net с использованием ImageDataGenerator. Сеть принимает изображения в качестве входных данных и выводит как обычную маску сегментации, так и маску сегментации в другом представлении. Таким образом, сеть дает два выхода, и в процессе обучения оба выхода должны использоваться в функции потерь. Таким образом, для каждого входного изображения необходимо иметь два основных изображения истинности. Для этого я создал собственный ImageDataGenerator в соответствии с:

def custom_generator(datagen, im_dir, mask_dir, mask_dt_dir, batch_size):
    seed = 1
    im_gen = datagen.flow_from_directory(im_dir, batch_size = batch_size, seed = seed) 
    mask_gen_1 = datagen.flow_from_directory(mask_dir, batch_size = batch_size, seed = seed)
    mask_gen_2 = datagen.flow_from_directory(mask_dt_dir, batch_size = batch_size, seed = seed)  
    while True:
        yield im_gen.next(), [mask_gen_1.next(), mask_gen_2.next()] 

Однако использование этого генератора в функции обучения

H = model.fit_generator(train_generator, epochs = num_epochs, verbose=1, steps_per_epoch = 
     math.floor(num_samples/batch_size))

дает ошибку:

AttributeError: объект 'tuple' не имеет атрибута 'shape'

Вероятно, это связано с тем, что выход генератора является кортежем. Однако я не знаю, какой тип данных следует использовать, и я видел предыдущие примеры, когда генераторы, использующие выходные данные кортежей, были успешно использованы.

Как создать собственный генератор, который работает в . fit_generator ? Есть ли альтернативный способ обучения многозадачной сети в Керасе?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...