Ошибка при использовании tf.keras.applications.resnet50.preprocess_input внутри tf.data.Dataset.map - PullRequest
0 голосов
/ 08 апреля 2020

У меня проблемы с функцией resnet50.preprocess_input () из tenorflow.compat.v1.keras.applications.resnet50

В частности, после Несколько проб и ошибок, я могу сказать, что проблема возникает, когда внутри функции генератора набора данных происходит вызов:

dataset.map(pre_processing_image)

, где

def pre_processing_image(image):
    image = resnet50.preprocess_input(image)
    return image

и набор данных разбивается на партии. Когда я достигаю последнюю партию, независимо от того, закончена она или меньше, я получаю ошибку, похожую на

Tensor («Const: 0», shape = (3,), dtype = float32) должен быть из того же графика, что и Tensor ("BatchDatasetV2: 0", shape = (), dtype = option)

Я не могу понять, что происходит, потому что

  • Если я использую другой preprocess_input, например, mobil enet, без изменения чего-либо еще, то проблем не будет. Копая код, я обнаружил, что все эти функции вызывают эту , но mobil enet использует "mode = 'tf'", в то время как re snet должен использовать 'caffe'
  • . ошибка не связана с тем, что последняя партия меньше по сравнению с другими, я пытался сделать их равными, но ошибки продолжают возникать на последнем этапе первой эпохи обучения
  • Если я не t map , но вместо этого pre_processing_image вызывается непосредственно внутри tf.data.Dataset.from_generator проблем нет ... только код становится намного медленнее

Чтобы дать вам полный код:

def image_gen(ds_path, ds_scores=None):
    for i, path in enumerate(ds_path):
        img = im.load_img(path,
                          color_mode='rgb',
                          target_size=(NETWORK_INFO.value[1],NETWORK_INFO.value[1]),
                          interpolation='bilinear')

        img_to_numpy = np.array(img)

        if (ds_scores is not None):
            yield img_to_numpy, ds_scores[i]
        else:
            yield img_to_numpy

def pre_processing_image(image, score=None):
    image = resnet50.preprocess_input(image)

    if score is None:
        return image
    else:
        return image, score

def generator(batchsize, train=False, val=False, test=False, shuffle=False):
    with tf.Session() as sess:    
        if (train):
            dataset = tf.data.Dataset.from_generator(lambda: image_gen(train_paths, train_scores),
                                                      output_types=(tf.float32, tf.float32))
        elif(val):
            dataset = tf.data.Dataset.from_generator(lambda: image_gen(val_paths, val_scores),
                                                      output_types=(tf.float32, tf.float32))
            else:
                dataset = tf.data.Dataset.from_generator(lambda: image_gen(test_paths),
                                                          output_types=(tf.float32))          

        if (shuffle):
            dataset = dataset.shuffle(buffer_size=10*batchsize)            

        dataset = dataset.batch(batchsize)        

        dataset = dataset.map(pre_processing_image,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)

        dataset = dataset.prefetch(buffer_size=2)

        dataset = dataset.repeat(count = -1)        

        iterable = tf.data.make_initializable_iterator(dataset)
        batch = iterable.get_next()
        sess.run(iterable.initializer)

        # yield all the time it is required
        while True:
            try:
                yield sess.run(batch)
            except tf.errors.OutOfRangeError:
                pass

Я пытался связываться с положением функции карты и параметрами shuffle / prefatch, но ничего не решило проблему. Наконец, как вы можете видеть, я использую одну и ту же функцию как для обучения, так и для генератора проверки, я просто изменяю входной параметр для выбора с набором данных, который функция должна использовать

1 Ответ

0 голосов
/ 08 апреля 2020

Решена проблема.

Я пытался найти что-то похожее, но в отношении других сетей, которые использовали одинаковую предварительную обработку изображений (например, VGG16), и выяснилось, что эти проблемы были ошибками в keras

Я обновил до последнего коммита модуль keras-application (commit, not release!) И код теперь работает без проблем

...