Tensorflow - Pad OR Truncate Sequence с набором данных API - PullRequest
0 голосов
/ 04 сентября 2018

Я пытаюсь использовать API набора данных для подготовки TFRecordDataset текстовых последовательностей. После обработки у меня есть словарь тензоров для каждой записи. Каждая запись содержит две последовательности.

Я использую padded_batch, чтобы применить отступ

dataset = dataset.padded_batch(batch_size, padded_shapes= {
    'seq1': tf.TensorShape([None]),
    'seq2': tf.TensorShape([None])
})

Это дополняет каждую последовательность до максимальной длины последовательности в пакете. Однако я хотел бы выбрать произвольную длину последовательности и заполнить ее этой длиной, когда истинная длина последовательности меньше, иначе обрезать последовательность.

Когда я пытаюсь заменить None на 100, например, я сталкиваюсь с DataLossError

DataLossError: Попытка заполнения до меньшего размера, чем элемент ввода.

Есть ли способ сделать это для достижения функциональности, аналогичной tf.image.resize_image_with_crop_or_pad в последовательности?

1 Ответ

0 голосов
/ 05 сентября 2018

Нет простого способа дополнить или усечь, но вы можете использовать функцию map, чтобы получить набор данных, содержащий элементы с желаемой длиной. Вот быстрый пример:

k = 4
def pad_or_trunc(t):
    dim = tf.size(t)
    return tf.cond(tf.equal(dim, k), lambda: t, lambda: tf.cond(tf.greater(dim, k), lambda: tf.slice(t, [0], [k]), lambda: tf.concat([t, tf.zeros(k-dim, dtype=tf.int32)], 0)))

vals = tf.constant([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
dset1 = tf.data.Dataset.from_tensor_slices(vals)
dset2 = dset1.map(pad_or_trunc)
iter = dset2.make_one_shot_iterator()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(iter.get_next()))
        except tf.errors.OutOfRangeError:
            break
...