Сглаживание кортежа изображений в наборе данных tenorflow - PullRequest
3 голосов
/ 11 июля 2019

У меня есть набор данных триплетных изображений, которые я читаю из tfrecords, которые я преобразовал в набор данных, используя следующий код

    def parse_dataset(record):
        def convert_raw_to_image_tensor(raw):
            raw = tf.io.decode_base64(raw)
            image_shape = tf.stack([299, 299, 3])
            decoded = tf.io.decode_image(raw, channels=3, 
                                dtype=tf.uint8, expand_animations=False)
            decoded = tf.cast(decoded, tf.float32)
            decoded = tf.reshape(decoded, image_shape)
            decoded = tf.math.divide(decoded, 255.)
            return decoded

        features = {
            'n': tf.io.FixedLenFeature([], tf.string),
            'p': tf.io.FixedLenFeature([], tf.string),
            'q': tf.io.FixedLenFeature([], tf.string)
        }
        sample = tf.io.parse_single_example(record, features)
        neg_image = sample['n']
        pos_image = sample['p']
        query_image = sample['q']

        neg_decoded = convert_raw_to_image_tensor(neg_image)
        pos_decoded = convert_raw_to_image_tensor(pos_image)
        query_decoded = convert_raw_to_image_tensor(query_image)
        return (neg_decoded, pos_decoded, query_decoded)

    record_dataset = tf.data.TFRecordDataset(filenames=path_dataset, num_parallel_reads=4)
    record_dataset = record_dataset.map(parse_dataset)

Форма этого результирующего набора данных

<MapDataset shapes: ((299, 299, 3), (299, 299, 3), (299, 299, 3)), types: (tf.float32, tf.float32, tf.float32)>

, что, я думаю, означает, что каждая запись содержит 3 изображения (что я подтвердил, просматривая набор данных и печатая 1-й, 2-й и 3-й элементы). Я хочу сгладить это, поэтому я получаю набор данных, который не содержит никаких кортежей, а просто плоский список изображений. Я попытался использовать flat_map, но он просто конвертирует изображения в (299, 3), и я попытался выполнить итерацию по набору данных, добавив каждое изображение в список, затем вызвав convert_to_tensor_slices, но это действительно неэффективно.

Я прочитал этот вопрос , но, похоже, это не помогло.

Кстати, это код flat_map, который я пробовал

record_dataset = record_dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))

и полученный набор данных имеет эту форму

<FlatMapDataset shapes: ((299, 3), (299, 3), (299, 3)), types: (tf.float32, tf.float32, tf.float32)>

1 Ответ

1 голос
/ 11 июля 2019

Я думаю, вы просто неправильно распаковываете кортеж.

это должно сделать это:

def flatten(*x):
  return tf.data.Dataset.from_tensor_slices([i for i in x])

flattened = record_dataset.flat_map(flatten)

так что:

for i in flattened:
  print(i.shape)

дает:

(299, 299, 3)
(299, 299, 3)
(299, 299, 3)
(299, 299, 3)
...

как и ожидалось

...