У меня небольшая проблема с повторной реализацией скрипта Pytorch с помощью нового API Tensorflow 2.0 (или Tensorflow в целом). В оригинальном скрипте Pytorch тензор обновляется с использованием логической маски (аналогично тому, как это делается в numpy):
# PyTorch
x_pred[empty_mask] = s_
x_pred
- тензор с формой (batch_size, 81, 9). Этот тензор должен обновляться несколько раз во время шага вперед сети.
s_
- тензор, содержащий вероятности softmax, генерируемые нейронной сетью. Форма зависит от количества пустых полей в empty_mask
. Первое измерение является переменным, второе измерение всегда равно 9.
empty_mask
- тензор с формой (batch_size, 81). Этот тензор верен для пустых полей в x_pred
. Эти пустые поля должны обновляться при каждом шаге вперед.
В данный момент я могу извлечь соответствующие части тензора x_pred
, используя маску.
extract = x_pred[empty_mask]
extract
и x_pred[empty_mask]
имеют одинаковую форму.
Когда я пытаюсь обновить тензор Tensorflow в том же стиле, я получаю следующее сообщение об ошибке:
# Tensorflow 2.0
x_pred = tf.Variable(x, trainable=False) # x is the input, so x_pred should be a copy an will be filled in the next steps
...
for i in range(max_empty_fields):
...
s_ = self.softmax(previous_layers) # Shape (???, 9)
x_pred[empty_fields_mask] = s_ # Update x_pred
...
# ==> TypeError: only size-1 arrays can be converted to Python scalars
Может кто-нибудь показать мне, как назначить эти обновления "на месте", чтобы x_pred
обновлялся напрямую?
Большое спасибо.
UPDATE
Я нашел решение, которое работает для меня. Сначала я должен перевести тензор маскирования в соответствующие индексы x_pred
, затем я могу использовать функцию tf.tensor_scatter_nd_update () , предоставляемую новым API TF2.
# Translate mask into indices using tf.where()
indices = tf.where(empty_fields_mask)
# Update the given indices of x_pred using tf.tensor_scatter_nd_update()
x_pred = tf.tensor_scatter_nd_update(x_pred, indices, s_)