Увеличение данных с помощью функции python с API tf.Dataset - PullRequest
0 голосов
/ 12 июня 2018

Я ищу динамически считываемые изображения и применяю увеличение данных для моей проблемы сегментации изображений.Из того, что я выглядел до сих пор, лучшим способом был бы tf.Dataset API с функцией .map.

Однако из примеров, которые я видел, я думаю, что мне придется адаптировать все свои функциив тензор потока (используйте tf.cond вместо if и т. д.).Проблема в том, что у меня есть некоторые действительно сложные функции, которые мне нужно применить.Поэтому я подумывал об использовании tf.py_func следующим образом:

import tensorflow as tf

img_path_list = [...]   # List of paths to read
mask_path_list = [...]  # List of paths to read

dataset = tf.data.Dataset.from_tensor_slices((img_path_list, mask_path_list))

def parse_function(img_path_list, mask_path_list):
    '''load image and mask from paths'''
    return img, mask

def data_augmentation(img, mask):
    '''process data with complex logic'''
    return aug_img, aug_mask

# py_func wrappers
def parse_function_wrapper(img_path_list, mask_path_list):
    return tf.py_func(func=parse_function,
                      inp=(img_path_list, mask_path_list),
                      Tout=(tf.float32, tf.float32))

def data_augmentation_wrapper(img, mask):
    return tf.py_func(func=data_augmentation,
                      inp=(img, mask),
                      Tout=(tf.float32, tf.float32))        

# Maps py_funcs to dataset
dataset = dataset.map(parse_function_wrapper,
                      num_parallel_calls=4)
dataset = dataset.map(data_augmentation_wrapper,
                      num_parallel_calls=4)

dataset = dataset.batch(32)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()

Однако из этого ответа кажется, что использование py_func для параллелизма не работает.Есть ли другая альтернатива?

1 Ответ

0 голосов
/ 12 июня 2018

py_func ограничен python GIL, поэтому параллелизма там не будет.Лучше всего записать добавление данных в собственном тензорном потоке (или предварительно вычислить его и сериализовать на диск).

Если вы хотите записать его в тензорном потоке, вы можете попробовать использовать tf.contrib.autograph дляконвертируйте простые циклы Python if и for в tf.conds и tf. while_loops, которые могут немного упростить ваш код.

...