Я хочу присвоить некоторые значения срезам входного тензора в одной из моих моделей в TensorFlow 2.x (я использую 2.2, но готов принять решение для 2.1). Нерабочий шаблон того, что я пытаюсь сделать, это:
import tensorflow as tf
from tensorflow.keras.models import Model
class AddToEven(Model):
def call(self, inputs):
outputs = inputs
outputs[:, ::2] += inputs[:, ::2]
return outputs
Конечно, при создании этого (AddToEven().build(tf.TensorShape([None, None]))
) я получаю следующую ошибку:
TypeError: 'Tensor' object does not support item assignment
Я могу достичь этого простого примера с помощью следующего:
class AddToEvenScatter(Model):
def call(self, inputs):
batch_size = tf.shape(inputs)[0]
n = tf.shape(inputs)[-1]
update_indices = tf.range(0, n, delta=2)[:, None]
scatter_nd_perm = [1, 0]
inputs_reshaped = tf.transpose(inputs, scatter_nd_perm)
outputs = tf.tensor_scatter_nd_add(
inputs_reshaped,
indices=update_indices,
updates=inputs_reshaped[::2],
)
outputs = tf.transpose(outputs, scatter_nd_perm)
return outputs
(вы можете проверить работоспособность с помощью:
model = AddToEvenScatter()
model.build(tf.TensorShape([None, None]))
model(tf.ones([1, 10]))
)
Но, как вы можете видеть, это очень сложно написать. И это только для статического c количества обновлений (здесь 1) на тензоре 1D (+ размер пакета).
То, что я хочу сделать, требует более сложного подхода, и я думаю, что напишу его с помощью tensor_scatter_nd_add
будет кошмаром.
Многие текущие QA на topi c охватывают случай переменных, но не тензоров (см., Например, this или this ). Здесь упоминается здесь , что pytorch действительно поддерживает это, поэтому я удивлен, что в последнее время не вижу ответа от каких-либо членов tf на этот c topi. Этот ответ мне не очень помогает, потому что мне понадобится какая-то генерация маски, которая тоже будет ужасной.
Вопрос в следующем: как я могу выполнить назначение срезов эффективно (с точки зрения вычислений, памяти и кода) без tensor_scatter_nd_add
? Хитрость в том, что я хочу, чтобы это было как можно более динамично, а это означает, что форма inputs
может быть переменной.
(Для всех, кому интересно, я пытаюсь перевести этот код in tf).
Этот вопрос изначально был размещен в выпуске GitHub .