Я хочу загрузить и дополнить пользовательский набор данных для сегментации. Для сегментации я подготовил файл npz
, содержащий четыре подмножества:
with np.load(PATH) as data:
train_x = data['x_train']
valid_x = data['x_valid']
train_y = data['y_train']
valid_y = data['y_valid']
Train / valid имеют соответствующие значения, а x / y обозначают входное изображение (x) и маску сегментации (y). В процессе обучения моя модель будет принимать входные данные x, а потери будут рассчитываться на основе выходных данных модели по отношению к y.
Теперь у меня вопрос: как на go вперед получить набор данных tf.data
, который я могу перебрать в обучении. Я пробовал следующее:
train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
>>> train_dataset
<TensorSliceDataset shapes: ((520, 696), (520, 696)), types: (tf.uint16, tf.uint8)>
def load(data_group):
image, mask = data_group
image = tf.cast(image, tf.float32)
mask = tf.cast(mask, tf.float32)
return image, mask
def normalize(image):
return (image / 65535/2) - 1
def load_image_train(data_group):
image, mask = load(data_group)
image = normalize(image)
# Perform augmentation (not shown)
return image, mask
train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)
Это, однако, не удается при попытке отобразить функцию поезда load_image_train
, возвращающую ошибку tf__load_image_train() takes 1 positional argument but 2 were given
. В целом такой подход выглядит немного неуклюжим, и ему хотелось бы узнать альтернативы / возможности для улучшения импорта данных.
Заранее спасибо