Как использовать `tf.scatter_nd` с многомерными тензорами - PullRequest
0 голосов
/ 10 июля 2019

Я пытаюсь создать новый тензор (output) со значениями другого тензора (updates), размещенного в соответствии с idx тензором. Форма output должна быть [batch_size, 1, 4, 4] (как изображение 2x2 пикселей и один канал), а update имеет форму [batch_size, 3].

Я прочитал документацию Tensorflow (я работаю с gpu версии 1.13.1) и обнаружил, что tf.scatter_nd должно работать для моей проблемы. Проблема в том, что я не могу заставить его работать, я думаю, у меня проблемы с пониманием того, как я должен расположить idx.

Давайте рассмотрим batch_size = 2, так что я делаю:

updates = tf.constant([[1, 2, 3], [4, 5, 6]])  # shape [2, 3]
output_shape = tf.constant([2, 1, 4, 4])
idx = tf.constant([[[1, 0], [1, 1], [1, 0]], [[0, 0], [0, 1], [0, 2]]])  # shape [2, 3, 2]
idx_expanded = tf.expand_dims(idx, 1)  # so I have shape [2, 1, 3, 2]
output = tf.scatter_nd(idx_expanded, updates, output_shape)

Я ожидаю, что это сработает, но это не так, мне выдается эта ошибка:

ValueError: The outer 3 dimensions of indices.shape=[2,1,3,2] must match the outer 3 dimensions of updates.shape=[2,3]: Shapes must be equal rank, but are 3 and 2 for 'ScatterNd_7' (op: 'ScatterNd') with input shapes: [2,1,3,2], [2,3], [4]

Я не понимаю, почему ожидается, что updates будет иметь размерность 3. Я подумал, что idx имеет смысл с output_shape (вот почему я использовал expand_dims), а также с updates (укажите два показателя для трех пунктов), но очевидно, что я что-то здесь упускаю.

Буду признателен за любую помощь.

1 Ответ

0 голосов
/ 11 июля 2019

Я играл с этой функцией, и я нашел свою ошибку.Если кто-то сталкивался с этой проблемой, вот что я сделал, чтобы решить ее:

Учитывая batch_size=2 и 3 точек, тензор idx должен иметь форму [2, 3, 4], где первое измерение соответствует партииоткуда мы берем значение update, второе измерение должно быть равно второму измерению updates (количество точек в пакете), а третье измерение - 4, потому что нам нужны 4 индексы: [batch_number, channel, ряд, столбец].Следуя примеру в вопросе:

updates = tf.constant([[1., 2., 3.], [4., 5., 6.]])  # [2, 3]
idx = tf.constant([[[0, 0, 0, 1], [0, 0, 0, 0], [0, 0, 1, 0]], [[1, 0, 1, 1], [1, 0, 0, 0], [1, 0, 1, 0]]])  # [2, 3, 4]
output = tf.scatter_nd(idx, updates, [2, 1, 4, 4])

sess = tf.Session()
print(sess.run(output))

[[[[2. 1. 0. 0.]
   [3. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]


 [[[5. 0. 0. 0.]
   [6. 4. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]]

Таким образом, можно поместить конкретные числа в новый тензор.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...