XLA не может вывести форму вывода постоянной времени компиляции для выделенного среза при использовании рваного тензора и цикла while - PullRequest
5 голосов
/ 13 апреля 2020

Можно ли получить следующий минимальный пример, работающий с experimental_compile=True? Я видел некоторые большие ускорения с этим аргументом, поэтому я стремлюсь выяснить, как заставить его работать. Спасибо!

import tensorflow as tf

print(tf.__version__)
# ===> 2.2.0-dev20200409

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

for i, tensor in enumerate(ragged_tensor):
    print(f"i: {i}\ntensor:\n{tensor}\n")
# ==>
# i: 0
# tensor:
# [[0. 1. 2. 3. 4.]
#  [5. 6. 7. 8. 9.]]

# i: 1
# tensor:
# [[10. 11. 12. 13. 14.]]

# i: 2
# tensor:
# [[15. 16. 17. 18. 19.]
#  [20. 21. 22. 23. 24.]]


@tf.function(autograph=False, experimental_compile=True)
def while_loop_fail():

    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        return i + 1, running_total + tf.reduce_sum(ragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total


while_loop_fail()
# ===>
# tensorflow.python.framework.errors_impl.InvalidArgumentError: XLA can't deduce compile time constant output shape for strided slice: [?,5], output shape must be a compile-time constant
#    [[{{node while/RaggedGetItem/strided_slice_4}}]]
#    [[while]]
#   This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to. [Op:__inference_while_loop_fail_481]

1 Ответ

4 голосов
/ 01 мая 2020

Кажется, есть много ограничений относительно того, что XLA может делать с рваными тензорами. Есть несколько альтернатив, которые, я думаю, могут заставить ваш пример работать, но я не знаю, будут ли они применимы к вашему реальному варианту использования. С одной стороны, вы могли бы заранее суммировать по неровным измерениям или даже по всем измерениям, кроме первого в вашем случае. Это, однако, должно быть сделано за пределами XLA, так как кажется, что он не может скомпилировать его:

import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

# Sum in advance
ragged_sum = tf.reduce_sum(ragged_tensor, axis=[1, 2])

@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():

    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        # Use the sums computed before
        return i + 1, running_total + ragged_sum[i]

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total


result = while_loop_works()
print(result.numpy())
# 300.0

Вы также можете просто преобразовать рваный тензор в обычный тензор, который дополнит его нули, которые не повлияют на вашу сумму. Опять же, это в настоящее время должно быть сделано из XLA:

import tensorflow as tf

x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)

# Convert into a regular tensor
unragged_tensor = ragged_tensor.to_tensor()

@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():
    num_rows = ragged_tensor.nrows()

    def cond(i, _):
        return i < num_rows

    def body(i, running_total):
        # Reduce padded tensor
        return i + 1, running_total + tf.reduce_sum(unragged_tensor[i])

    _, total = tf.while_loop(cond, body, [0, 0.0])

    return total


result = while_loop_works()
print(result.numpy())
# 300.0
...