Конвейер данных в tf.keras с tfrecords или numpy - PullRequest
0 голосов
/ 02 апреля 2019

Я хочу обучить модель в tf.keras из Tensorflow 2.0 с данными, которые больше, чем у моего плунжера, но в руководствах показаны только примеры с предопределенными наборами данных.

Я следовал этому уроку:

Загрузка изображений с tf.data , я не мог выполнить эту работу для данных на массивах numpy или tfrecords.

Это пример преобразования массива в наборы данных tenorflow.Я хочу, чтобы это работало для нескольких файлов массивов или нескольких файлов tfrecords.

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Since the dataset already takes care of batching,
# we don't pass a `batch_size` argument.
model.fit(train_dataset, epochs=3)

1 Ответ

0 голосов
/ 02 апреля 2019

Если у вас есть tfrecords файлы:

path = ['file1.tfrecords', 'file2.tfrecords', ..., 'fileN.tfrecords']
dataset = tf.data.Dataset.list_files(path, shuffle=True).repeat()
dataset = dataset.interleave(lambda filename: tf.data.TFRecordDataset(filename), cycle_length=len(path))
dataset = dataset.map(parse_function).batch()

parse_function обрабатывает декодирование и любое расширение.

В случае с массивами numpy, вы можете создать набор данных из спискаимена файлов или из списка массивов.Ярлыки - это просто список.Или они могут быть взяты из файла при разборе одного примера.

path = #list of numpy arrays

или

path = os.listdir(path_to files)

dataset = tf.data.Dataset.from_tensor_slices((path, labels))
dataset = dataset.map(parse_function).batch()

parse_function обрабатывает декодирование:

def parse_function(filename, label):  #Both filename and label will be passed if you provided both to from_tensor_slices
    f = tf.read_file(filename)
    image = tf.image.decode_image(f)) 
    image = tf.reshape(image, [H, W, C])
    label = label #or it could be extracted from, for example, filename, or from file itself 
    #do any augmentations here
    return image, label

Для декодирования файлов .npy,лучший способ - использовать reshape без read_file или decode_raw, но сначала загрузить numpys с помощью np.load:

paths = [np.load(i) for i in ["x1.npy", "x2.npy"]]
image = tf.reshape(filename, [2])

или попробовать использовать decode_raw

f = tf.io.read_file(filename)
image = tf.io.decode_raw(f, tf.float32)

Затем просто передайте пакетный набор данных на model.fit(dataset).TensorFlow 2.0 позволяет выполнять итерацию по набору данных.Не нужно использовать итератор.Даже в более поздних версиях 1.x API вы можете просто передать набор данных в .fit метод

for example in dataset:
    func(example)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...