У меня есть набор данных триплетных изображений, которые я читаю из tfrecords, которые я преобразовал в набор данных, используя следующий код
def parse_dataset(record):
def convert_raw_to_image_tensor(raw):
raw = tf.io.decode_base64(raw)
image_shape = tf.stack([299, 299, 3])
decoded = tf.io.decode_image(raw, channels=3,
dtype=tf.uint8, expand_animations=False)
decoded = tf.cast(decoded, tf.float32)
decoded = tf.reshape(decoded, image_shape)
decoded = tf.math.divide(decoded, 255.)
return decoded
features = {
'n': tf.io.FixedLenFeature([], tf.string),
'p': tf.io.FixedLenFeature([], tf.string),
'q': tf.io.FixedLenFeature([], tf.string)
}
sample = tf.io.parse_single_example(record, features)
neg_image = sample['n']
pos_image = sample['p']
query_image = sample['q']
neg_decoded = convert_raw_to_image_tensor(neg_image)
pos_decoded = convert_raw_to_image_tensor(pos_image)
query_decoded = convert_raw_to_image_tensor(query_image)
return (neg_decoded, pos_decoded, query_decoded)
record_dataset = tf.data.TFRecordDataset(filenames=path_dataset, num_parallel_reads=4)
record_dataset = record_dataset.map(parse_dataset)
Форма этого результирующего набора данных
<MapDataset shapes: ((299, 299, 3), (299, 299, 3), (299, 299, 3)), types: (tf.float32, tf.float32, tf.float32)>
, что, я думаю, означает, что каждая запись содержит 3 изображения (что я подтвердил, просматривая набор данных и печатая 1-й, 2-й и 3-й элементы). Я хочу сгладить это, поэтому я получаю набор данных, который не содержит никаких кортежей, а просто плоский список изображений. Я попытался использовать flat_map, но он просто конвертирует изображения в (299, 3), и я попытался выполнить итерацию по набору данных, добавив каждое изображение в список, затем вызвав convert_to_tensor_slices, но это действительно неэффективно.
Я прочитал этот вопрос , но, похоже, это не помогло.
Кстати, это код flat_map, который я пробовал
record_dataset = record_dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
и полученный набор данных имеет эту форму
<FlatMapDataset shapes: ((299, 3), (299, 3), (299, 3)), types: (tf.float32, tf.float32, tf.float32)>