Как получить доступ к тензорной форме внутри функции карты - PullRequest
0 голосов
/ 01 августа 2020

Мне нужно получить доступ к фигурам изображения для выполнения конвейера аугментации, хотя при доступе через image.shape[0] and image.shape[1] я не могу выполнять аугментации, поскольку он выводит, что мои тензоры имеют форму None. * Как получить доступ к тензорной форме в .map?

Благодарю, если кто-нибудь может помочь.

parsed_dataset = tf.data.TFRecordDataset(filenames=train_records_paths).map(parsing_fn) # Returns [image,label]
augmented_dataset = parsed_dataset.map(augment_pipeline) 
augmented_dataset = augmented_dataset.unbatch()

Сопоставленная функция

""" 
    Returns:
      5 Versions of the original image: 4 corner crops + a central crop and the respective labels.
"""
def augment_pipeline(original_image,label):
  central_crop = lambda image: tf.image.central_crop(image,0.5)
  corner_crops = lambda image: tf.image.extract_patches(images=tf.expand_dims(image,0), # Transform image in a batch of single sample
                                                sizes=[1, int(0.5 * image.shape[0]), int(0.5 * image.shape[1]), 1], # 50% of the image's height and width
                                                rates=[1, 1, 1, 1],
                                                strides=[1, int(0.5 * image.shape[0]), int(0.5 * image.shape[1]), 1],
                                                padding="SAME")
  reshaped_patches = tf.reshape(corner_crops(original_image), [-1,int(0.5*original_image.shape[0]),int(0.5*original_image.shape[1]),3])
  images = tf.concat([reshaped_patches,tf.expand_dims(central_crop(original_image),axis=0)],axis=0)
  label = tf.reshape(label,[1,1])
  labels = tf.tile(label,[5,1])
  return images,labels

Ответы [ 2 ]

1 голос
/ 02 августа 2020

После дальнейшего исследования мне удалось обойтись, используя py_func, как было предложено здесь и tf.shape(image)[0] здесь .

Код:

""" 
    Returns:
      5 Versions of the original image: 4 corner crops + a central crop and the respective labels.
"""
def augment_pipeline(original_image,label):
  height  = int(tf.shape(original_image)[0].numpy() * 0.5)  # 50% of the image's height and width
  width = int(tf.shape(original_image)[1].numpy() * 0.5)
  central_crop = lambda image: tf.image.central_crop(image,0.5)
  corner_crops = lambda image: tf.image.extract_patches(images=tf.expand_dims(image,0), # Transform image in a batch of single sample
                                                sizes=[1, height, width, 1],
                                                rates=[1, 1, 1, 1],
                                                strides=[1, height, width, 1],
                                                padding="SAME")

                                              .
                                              .
                                              .

Затем мы используем py_func, чтобы разрешить доступ к numpy значениям внутри функции карты:

parsed_dataset = tf.data.TFRecordDataset(filenames=train_records_paths).map(parsing_fn) # Returns [image,label]
augmented_dataset = parsed_dataset.map(lambda image,label: tf.py_function(func=augment_pipeline,
                                                                          inp=[image,label],
                                                                          Tout=[tf.float32,tf.int64])) 
augmented_dataset = augmented_dataset.unbatch()
0 голосов
/ 27 августа 2020

Каждый объект набора данных повторяется. Теперь объект набора данных может быть либо в пакетной форме, либо в неупакованной форме. Я расскажу вам, как получить форму их элементов в обоих случаях. его элементы с использованием iter

it = iter(dataset)
element = next(it)
image,label = element
## element is a tuple

Метод 2. с использованием take

element = dataset.take(1)
image,label = element
# element is a tuple

Случай 2. Когда набор данных группируется. Теперь я предполагаю, что набор данных содержит кортежи (изображение, метка)

Метод 1. Использование iter

it = iter(dataset)
batch = next(it)
images,labels = batch
## batch is a tuple check it using type(batch)

Метод 2. Использование take

batch = dataset.take(1)
## Note here each element of the dataset is a batch and each batch contains some number of 
## (image,label) tuples
batch = next(iter(batch))
images,labels = batch
## batch is again a tuple
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...