Кажется, есть много ограничений относительно того, что XLA может делать с рваными тензорами. Есть несколько альтернатив, которые, я думаю, могут заставить ваш пример работать, но я не знаю, будут ли они применимы к вашему реальному варианту использования. С одной стороны, вы могли бы заранее суммировать по неровным измерениям или даже по всем измерениям, кроме первого в вашем случае. Это, однако, должно быть сделано за пределами XLA, так как кажется, что он не может скомпилировать его:
import tensorflow as tf
x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)
# Sum in advance
ragged_sum = tf.reduce_sum(ragged_tensor, axis=[1, 2])
@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():
num_rows = ragged_tensor.nrows()
def cond(i, _):
return i < num_rows
def body(i, running_total):
# Use the sums computed before
return i + 1, running_total + ragged_sum[i]
_, total = tf.while_loop(cond, body, [0, 0.0])
return total
result = while_loop_works()
print(result.numpy())
# 300.0
Вы также можете просто преобразовать рваный тензор в обычный тензор, который дополнит его нули, которые не повлияют на вашу сумму. Опять же, это в настоящее время должно быть сделано из XLA:
import tensorflow as tf
x = tf.reshape(tf.range(25, dtype=tf.float32), [5, 5])
row_lengths = tf.constant([2, 1, 2])
ragged_tensor = tf.RaggedTensor.from_row_lengths(x, row_lengths)
# Convert into a regular tensor
unragged_tensor = ragged_tensor.to_tensor()
@tf.function(autograph=False, experimental_compile=True)
def while_loop_works():
num_rows = ragged_tensor.nrows()
def cond(i, _):
return i < num_rows
def body(i, running_total):
# Reduce padded tensor
return i + 1, running_total + tf.reduce_sum(unragged_tensor[i])
_, total = tf.while_loop(cond, body, [0, 0.0])
return total
result = while_loop_works()
print(result.numpy())
# 300.0