XLA-совместимый динамический c нарезка - PullRequest
0 голосов
/ 29 мая 2020

Есть ли способ динамически разрезать тензор в соответствии с генератором случайных чисел в функции, скомпилированной в XLA? Например:

@tf.function(experimental_compile=True)
def random_slice(input, max_slice_size):
    offset = tf.squeeze(tf.random.uniform([1], minval=0, maxval=input.shape[0]-max_slice_size, dtype=tf.int32))
    sz = tf.squeeze(tf.random.uniform([1], minval=1, maxval=max_slice_size, dtype=tf.int32))

    indices = tf.range(offset, offset+sz)  # Non-XLA-able due to non-static bounds

    return tf.gather(input, indices)

x = tf.ones([50, 50])
y = random_slice(x, 4)

Этот код не может быть скомпилирован, потому что XLA требует, чтобы аргументы tf.range были известны во время компиляции. Есть ли рекомендуемый обходной путь?

1 Ответ

0 голосов
/ 30 мая 2020

Основная проблема здесь в том, что XLA необходимо статически знать формы всех Tensor в программе. В этом случае он жалуется на tf.range, потому что его результат не известен с учетом случайных входов. Вместо этого вы могли бы уйти с генерации замаскированной версии (обнуление элементов, которые вам не нужны, используя что-то вроде tensor_scatter_nd_update) и использование этой замаскированной версии ниже по течению (трудно сказать точно, как, не видя большего контекста того, как y будет использоваться).

...