Могу ли я сгенерировать метку uint8, используя ImageDataGenerator () и flow_from_directory () в Keras? - PullRequest
0 голосов
/ 11 сентября 2018

Я имею дело с 2D-семантической сегментацией .

в Документах Keras API. В них есть только примеры, показывающие, как упорядочить набор данных для классификации изображений, а не семантической сегментации.

Итак, я расположил свое изображение и метку следующим образом

SEED = 111
batch_size = 2
image_datagen = ImageDataGenerator(
    horizontal_flip=True,
    zca_epsilon=9,
    # fill_mode='nearest',
)
image_generator = image_datagen.flow_from_directory(
    directory="/xxx/images",
    class_mode=None,
    batch_size=batch_size,
    seed=SEED,
)


def preprocessing_function(image):
    return image.astype(np.uint8)


label_datagen = ImageDataGenerator(
    horizontal_flip=True,
    zca_epsilon=9,
    rescale=1,
    preprocessing_function=preprocessing_function,
    # fill_mode='nearest',
)
label_generator = image_datagen.flow_from_directory(
    directory="/xxx/labels",
    class_mode=None,
    batch_size=batch_size,
    seed=SEED,
)

train_generator = zip(image_generator, label_generator)
print(len(image_generator))
i = 0
for image_batch, label_batch in iter(train_generator):
    print(image_batch.shape, label_batch.shape) # (2, 256, 256, 3) (2, 256, 256, 3)
    print(image_batch.dtype, label_batch.dtype) # float32 float32
    i += 1
    if i == 5:
        break

Но Я обнаружил, что тип сгенерированных изображений меток: float32 , поэтому Я добавляю функцию preprocessing_functionto label_datagen только для приведения dtype к uint8, , но dtype сгенерированных изображений меток по-прежнему float32 , казалось, что preprocessing_function ничего не сделала.

Как я могу исправить эту проблему?

Как изменить данные моей метки на uint8?

Является ли "обычной практикой" добавление функции предварительной обработки для приведения изображений метки dtype?

Спасибо за любыесовет!

1 Ответ

0 голосов
/ 17 июня 2019

Я встретил ту же проблему и обернул генератор в другой. Это работает, но это своего рода kludge

label_generator = (x.astype(np.uint8) for x in label_generator)
train_generator = zip(image_generator, label_generator)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...