Проблемы с функцией карты набора данных Tensorflow - PullRequest
0 голосов
/ 17 апреля 2020

Я использую tenorflow 2.1 для построения конвейера данных. Я написал функцию для предварительной обработки данных:

def preprocessing(path):
    path = str(path.numpy(), 'utf-8')
    label = Path(path).parent.name
    image = tf.io.read_file(path)
    image = tf.image.decode_image(image)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.central_crop(image, central_fraction=0.5)
    image = tf.image.resize(image, size=[224, 224])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    return image, label

Когда я проверяю функцию обработки, используя следующие коды, она работает.

ds = tf.data.Dataset.list_files('../datasets/hymenoptera_data/train/ants/*.jpg')
path = next(iter(ds))
image, label = preprocessing(path)
plt.imshow(image)
plt.show()

и результат печати (путь) - tf. Тензор (b '.. \ datasets \ hymenoptera_data \ train \ ants \ 886401651_f878e888cd.jpg', shape = (), dtype = string) Но если я использую map () для обработки сгенерированных ds, выдается ошибка:

ds_new = ds.map(preprocessing, num_parallel_calls=tf.data.experimental.AUTOTUNE)
for i in ds_new.take(1):
    plt.imshow(i)
    plt.show()

AttributeError: у объекта 'Tensor' нет атрибута 'numpy', эта ошибка произошла из-за path = str (path. numpy (), 'utf-8') в функции предварительной обработки.

Я не понимаю, почему, кто может помочь в этом вопросе, действительно ценю!

1 Ответ

0 голосов
/ 17 апреля 2020

Попробуйте эту функцию для предварительной обработки:

def preprocessing(path):
    label = tf.strings.split(path, os.path.sep)[-2]
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.central_crop(image, central_fraction=0.5)
    image = tf.image.resize(image, size=[224, 224])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    return image, label

Работает как с простой загрузкой, так и с tf.data:

import tensorflow as tf
import os
import matplotlib.pyplot as plt

paths = tf.data.Dataset.list_files('images/*.jpg')
path = next(iter(paths))
image, label = preprocessing(path)
plt.imshow(image)
plt.show()

filenames = tf.data.Dataset.list_files('images/*.jpg')
ds = filenames.map(preprocessing)
for image, label in ds.take(1):
    plt.imshow(image)
    plt.show()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...