Тензор потока: извлечение последовательных патчей из сложного тензора произвольной длины - PullRequest
3 голосов
/ 17 января 2020

Я пытаюсь выяснить, как извлечь последовательные патчи из сложного многозначного тензора, где длина является переменной. Извлечение выполняется как часть конвейера tf.data.

Если бы тензор не был сложным, я бы использовал tf.image.extract_image_patches, как в этом ответе .

Однако эта функция не работает со сложными тензорами. Я попробовал следующую технику, но она не работает, потому что длина тензора неизвестна.

def extract_sequential_patches(image):
    image_length = tf.shape(image)[0]
    num_patches = image_length // (128 // 4)
    patches = []
    for i in range(num_patches):
        start = i * 128
        end = start + 128
        patches.append(image[start:end, ...])
    return tf.stack(patches)

Однако я получаю ошибку:

InaccessibleTensorError: The tensor 'Tensor("strided_slice:0", shape=(None, 512, 2), dtype=complex64)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_2100, id=140313967335120)

Я пробовал либеральное оформление с помощью @tf.function

1 Ответ

1 голос
/ 22 января 2020

Я думаю, вам нужно будет скорректировать расчет индексов, чтобы убедиться, что они не выходят за пределы go, но, оставив эту деталь в стороне, ваш код - почти то, что tf.function ожидает, за исключением использования из списка Python; вместо этого вам нужно использовать TensorArray.

Что-то вроде этого должно работать (вычисления индекса могут быть не совсем правильными):

@tf.function
def extract_sequential_patches(image, size, stride):
    image_length = tf.shape(image)[0]
    num_patches = (image_length - size) // stride + 1
    patches = tf.TensorArray(image.dtype, size=num_patches)
    for i in range(num_patches):
        start = i * stride
        end = start + size
        patches = patches.write(i, image[start:end, ...])
    return patches.stack()

Вы можете найти более подробную информацию о том, почему Python перечисляет дон в настоящее время не работает в справочных документах автографа .

Тем не менее, может быть быстрее использовать трюк real / imag, если ядро ​​extract_image_patches оптимизировано. Я рекомендую протестировать оба подхода.

...