Загрузка данных сегментации с tf.data как набор данных? - PullRequest
0 голосов
/ 08 апреля 2020

Я хочу загрузить и дополнить пользовательский набор данных для сегментации. Для сегментации я подготовил файл npz, содержащий четыре подмножества:

with np.load(PATH) as data:
    train_x = data['x_train']
    valid_x = data['x_valid']
    train_y = data['y_train']
    valid_y = data['y_valid']

Train / valid имеют соответствующие значения, а x / y обозначают входное изображение (x) и маску сегментации (y). В процессе обучения моя модель будет принимать входные данные x, а потери будут рассчитываться на основе выходных данных модели по отношению к y.

Теперь у меня вопрос: как на go вперед получить набор данных tf.data, который я могу перебрать в обучении. Я пробовал следующее:

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))

>>> train_dataset
<TensorSliceDataset shapes: ((520, 696), (520, 696)), types: (tf.uint16, tf.uint8)>

def load(data_group):
    image, mask = data_group
    image = tf.cast(image, tf.float32)
    mask = tf.cast(mask, tf.float32)
    return image, mask

def normalize(image):
    return (image / 65535/2) - 1

def load_image_train(data_group):
    image, mask = load(data_group)
    image = normalize(image)
    # Perform augmentation (not shown)
    return image, mask

train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

Это, однако, не удается при попытке отобразить функцию поезда load_image_train, возвращающую ошибку tf__load_image_train() takes 1 positional argument but 2 were given. В целом такой подход выглядит немного неуклюжим, и ему хотелось бы узнать альтернативы / возможности для улучшения импорта данных.

Заранее спасибо

1 Ответ

1 голос
/ 08 апреля 2020

Вы должны написать так:

def load_image_train(image,mask):

  image = tf.cast(image, tf.float32)
  mask = tf.cast(mask, tf.float32)
  image = normalize(image)

  return image, mask

tf.data.Dataset вернет пару тензоров в вашем случае.

Также ознакомьтесь с Tensorflow Guide

...