Понимание того, как использовать tf.dataset.map () - PullRequest
1 голос
/ 02 апреля 2020

Я конвертирую некоторый код, который первоначально использовал JPEG в качестве входных данных для использования файлов Matlab MAT. Код содержит строки:

train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.mat')
train_dataset = train_dataset.shuffle(BUFFER_SIZE) 
train_dataset = train_dataset.map(load_image_train)

Если I l oop через набор данных и распечатать () каждый элемент перед map (), я получу набор тензоров с видимыми путями к файлам.

Однако в функции load_image_train это не так, вывод print ():

Tensor("add:0", shape=(), dtype=string)

Я хотел бы использовать scipy.io.loadmat () функция для получения данных из моих файлов матов, но она не работает, потому что путь является тензором, а не строкой. Что делает dataset.map (), что делает буквенное строковое значение больше не видимым? Как извлечь строку, чтобы я мог использовать ее в качестве входных данных для scipy.io.loadmat ()?

Извинения, если это глупый вопрос, относительно новый для Tensorflow и все еще пытающийся понять. Много дискуссий, которые я могу найти по связанным вопросам, относится только к TF v1. Спасибо за любую помощь!

1 Ответ

0 голосов
/ 08 мая 2020

В приведенном ниже коде я использую tf.data.Dataset.list_files для чтения file_path изображения. В функции map я загружаю изображение и выполняю crop_central (в основном кадрирует центральную часть изображения для заданного процента, здесь я указал процент np.random.uniform(0.50, 1.00)).

Как вы правильно заметили, трудно прочитать файл, поскольку путь к файлу имеет тип tf.string, а для load_img или любой другой функции для чтения файла изображения потребуется простой тип string .

Итак, вот как вы можете это сделать -

  1. Вам нужно украсить свою функцию карты с помощью tf.py_function(load_file_and_process, [x], [tf.float32]). Вы можете узнать больше об этом здесь .
  2. Вы можете извлечь string из tf.string, используя bytes.decode(path.numpy().

Ниже приведен полный код для справки. Вы можете заменить его на путь к изображению во время выполнения этого кода.

%tensorflow_version 2.x
import tensorflow as tf
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array, array_to_img
from matplotlib import pyplot as plt
import numpy as np

def load_file_and_process(path):
    image = load_img(bytes.decode(path.numpy()), target_size=(224, 224))
    image = img_to_array(image)
    image = tf.image.central_crop(image, np.random.uniform(0.50, 1.00))
    return image

train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
train_dataset = train_dataset.map(lambda x: tf.py_function(load_file_and_process, [x], [tf.float32]))

for f in train_dataset:
  for l in f:
    image = np.array(array_to_img(l))
    plt.imshow(image)

Вывод -

enter image description here

Надеюсь, что это ответ на ваш вопрос. Счастливого обучения.

...