как применить функцию карты к tf.Tensor - PullRequest
1 голос
/ 29 мая 2020
dataset = tf.data.Dataset.from_tensor_slices((images,boxes))
function_to_map = lambda x,y: func3(x,y)
fast_benchmark(dataset.map(function_to_map).batch(1).prefetch(tf.data.experimental.AUTOTUNE))

Теперь я вот func3

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    print('dataset->',dataset)
    for _ in tf.data.Dataset.range(num_epochs):
        for _,__ in dataset:
            print(_,__)
            break
            pass

вывод печати

tf.Tensor([b'/media/jake/mark-4tb3/input/datasets/pascal/VOCtrainval_11-May-2012/VOCdevkit/VOC2012/JPEGImages/2008_000008.jpg'], shape=(1,), dtype=string) <tf.RaggedTensor [[[52, 86, 470, 419], [157, 43, 288, 166]]]>

что я хочу сделать в func3 () хотите изменить каталог изображений на реальное изображение и запустить пакет

1 Ответ

0 голосов
/ 01 июня 2020

Вам нужно извлечь строку из тензора и использовать соответствующую функцию чтения изображения. Ниже приведены шаги, которые необходимо реализовать в коде для достижения этой цели.

  1. Вы должны украсить функцию карты с помощью tf.py_function(get_path, [x], [tf.float32]). Вы можете найти больше о tf.py_function здесь . В tf.py_function первый аргумент - это имя функции map, второй аргумент - это элемент, который должен быть передан функции map, а последний аргумент - это возвращаемый тип.
  2. Вы можете получить свою строковую часть, используя bytes.decode(file_path.numpy()) в функции карты.
  3. Используйте соответствующую функцию для загрузки изображения. Мы используем load_img.

В приведенной ниже простой программе мы используем tf.data.Dataset.list_files для чтения пути изображения. Затем в функции map мы читаем изображение, используя load_img, а затем выполняем функцию tf.image.central_crop для обрезки центральной части изображения.

Код -

%tensorflow_version 2.x
import tensorflow as tf
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array, array_to_img
from matplotlib import pyplot as plt
import numpy as np

def load_file_and_process(path):
    image = load_img(bytes.decode(path.numpy()), target_size=(224, 224))
    image = img_to_array(image)
    image = tf.image.central_crop(image, np.random.uniform(0.50, 1.00))
    return image

train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
train_dataset = train_dataset.map(lambda x: tf.py_function(load_file_and_process, [x], [tf.float32]))

for f in train_dataset:
  for l in f:
    image = np.array(array_to_img(l))
    plt.imshow(image)

Вывод -

enter image description here

Надеюсь, это ответит на ваш вопрос. Удачного обучения.

...