как загрузить данные меток, представленные в растровом формате, в Keras / Tensorflow - PullRequest
0 голосов
/ 21 сентября 2018

Я хочу использовать сеть CNN для сегментирования объектов 2 (двоичный: «0: объект не представлен, 1: объект присутствует»), но у меня есть проблема с данными.Данные поезда представляют собой 150 изображений и в формате «jpg», а основную правду (данные метки) также составляют 150 изображений растров «png» 0 и 1 (что приводит к черно-белым изображениям).

Теперь вопрос заключается в том, как загрузить этот гибрид изображений поездов и изображений меток в Keras / Tensorflow, и если есть фиктивный пример и / или демонстрация того, как это сделать в Python, я был бы признателен.

1 Ответ

0 голосов
/ 21 сентября 2018

Вы можете определить один генератор для чтения входных изображений и другой для чтения меток, используя класс ImageDataGenerator и его метод flow_from_directory(), а затем объединить эти двагенераторы в одном генераторе.Просто убедитесь, что структура каталогов и (порядок) имен файлов ввода и меток одинаковы:

data_image_gen = ImageDataGenerator(...)
data_label_gen = ImageDataGenerator(...)

image_gen = data_image_gen.flow_from_directory(image_directory,
                # no need to return labels
                class_mode=None,
                # don't shuffle to have the same order as labels
                shuffle=False)

image_gen = data_image_gen.flow_from_directory(label_directory,
                color_mode='grayscale',
                # no need to return labels
                class_mode=None,
                # don't shuffle to have the same order as images 
                shuffle=False)

def final_gen(image_gen, label_gen):
    for data, labels in zip(image_gen, label_gen):
        # divide labels by 255 to make them like masks i.e. 0 and 1
        labels /= 255.
        # remove the last axis, i.e. (batch_size, n_rows, n_cols, 1) --> (batch_size, n_rows, n_cols)
        labels = np.squeeze(labels, axis=-1)

        yield data, labels

# ... define your model

# fit the model
model.fit_generator(final_gen(image_gen, label_gen), ...)
...