Как получить форму тензора из TFRecordDataset - PullRequest
0 голосов
/ 24 августа 2018

У меня есть следующая особенность, написанная для моей тренировки TFRecord:

feature = {'label': _int64_feature(gt),
           'image': _bytes_feature(tf.compat.as_bytes(im.tostring())),
           'height': _int64_feature(h),
           'width': _int64_feature(w)}

и я читаю это как:

train_dataset = tf.data.TFRecordDataset(train_file)
train_dataset = train_dataset.map(parse_func)
train_dataset = train_dataset.shuffle(buffer_size=1)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(batch_size)

тогда как мой parse_func выглядит так:

def parse_func(ex):
    feature = {'image': tf.FixedLenFeature([], tf.string),
               'label': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
               'height': tf.FixedLenFeature([], tf.int64),
               'width': tf.FixedLenFeature([], tf.int64)}
    features = tf.parse_single_example(ex, features=feature)
    image = tf.decode_raw(features['image'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    im_shape = tf.stack([width, height])
    image = tf.reshape(image, im_shape)
    label = tf.cast(features['label'], tf.int32)
    return image, label

Теперь я хочу получить форму image и label вроде:

image.get_shape().as_list()

который печатает
[Нет, Нет, Нет]
вместо
[Нет, 224, 224] (размер изображения (серия, ширина, высота))

Есть ли какая-нибудь функция, которая может дать мне размер этих тензоров?

1 Ответ

0 голосов
/ 24 августа 2018

Поскольку ваша функция карты "parse_func" является частью графа как операция, и она не знает фиксированный размер вашего ввода и априори помечает, использование get_shape () не вернет ожидаемую фиксированную форму.

Если ваше изображение, форма ярлыка исправлена, как хак, вы можете попытаться изменить свое изображение, ярлыки с уже известным размером (это фактически ничего бы не сделало, но явно установит размервыходные тензоры).

отл.image = tf.reshape (image, [224,224])

С этим вы сможете получить результат get_shape (), как и ожидалось.

...