Tensorflow 2.0: назначение обновлений для тензора на основе маски - PullRequest
0 голосов
/ 22 апреля 2019

У меня небольшая проблема с повторной реализацией скрипта Pytorch с помощью нового API Tensorflow 2.0 (или Tensorflow в целом). В оригинальном скрипте Pytorch тензор обновляется с использованием логической маски (аналогично тому, как это делается в numpy):

# PyTorch
x_pred[empty_mask] = s_

x_pred - тензор с формой (batch_size, 81, 9). Этот тензор должен обновляться несколько раз во время шага вперед сети.

s_ - тензор, содержащий вероятности softmax, генерируемые нейронной сетью. Форма зависит от количества пустых полей в empty_mask. Первое измерение является переменным, второе измерение всегда равно 9.

empty_mask - тензор с формой (batch_size, 81). Этот тензор верен для пустых полей в x_pred. Эти пустые поля должны обновляться при каждом шаге вперед.

В данный момент я могу извлечь соответствующие части тензора x_pred, используя маску.

extract = x_pred[empty_mask]

extract и x_pred[empty_mask] имеют одинаковую форму.

Когда я пытаюсь обновить тензор Tensorflow в том же стиле, я получаю следующее сообщение об ошибке:

# Tensorflow 2.0 
x_pred = tf.Variable(x, trainable=False)   # x is the input, so x_pred should be a copy an will be filled in the next steps

...

for i in range(max_empty_fields):

   ...
   s_ = self.softmax(previous_layers)  # Shape (???, 9)
   x_pred[empty_fields_mask] = s_      # Update x_pred

   ...


# ==> TypeError: only size-1 arrays can be converted to Python scalars

Может кто-нибудь показать мне, как назначить эти обновления "на месте", чтобы x_pred обновлялся напрямую?

Большое спасибо.

UPDATE

Я нашел решение, которое работает для меня. Сначала я должен перевести тензор маскирования в соответствующие индексы x_pred, затем я могу использовать функцию tf.tensor_scatter_nd_update () , предоставляемую новым API TF2.

# Translate mask into indices using tf.where()
indices = tf.where(empty_fields_mask)

# Update the given indices of x_pred using tf.tensor_scatter_nd_update()
x_pred = tf.tensor_scatter_nd_update(x_pred, indices, s_)
...