Как установить индексы для получения последнего измерения трехмерного тензора с помощью tf.tensor_scatter_nd_update? - PullRequest
1 голос
/ 27 апреля 2020

У меня есть 3D-тензор тензор_a, tensor_a = (510, 1, 6), и я хочу обновить tensor_a[..., 0] и tensor_a[...,-1], например tensor_a[..., 0] = 1. в pytorch или numpy. Как установить indices право на достижение того же результата, что и tensor_a[...,0] = 1. в pytorch?

1 Ответ

1 голос
/ 27 апреля 2020

Вы можете сделать это с помощью tf.tensor_scatter_nd_update следующим образом:

import tensorflow as tf

tensor_a = ...  # Some 3D tensor
idx_to_replace = 0
new_value = 1
s = tf.shape(tensor_a)
i1, i2 = tf.meshgrid(tf.range(s[0]), tf.range(s[1]), indexing='ij')
i3 = idx_to_replace * tf.ones_like(i1)
idx = tf.stack([i1, i2, i3], axis=-1)
updates = new_value * tf.ones_like(i1, dtype=tensor_a.dtype)
result = tf.tensor_scatter_nd_update(tensor_a, idx, updates)

Хотя это не работает с отрицательными индексами, вам нужно сделать его положительным, например с помощью:

idx_to_replace = tf.cond(tf.less(idx_to_replace, 0),
                         lambda: idx_to_replace + s[-1],
                         lambda: idx_to_replace)

Однако, чтобы заменить первый индекс последнего измерения на единицы, вам может быть проще и быстрее сделать что-то вроде этого:

result = tf.concat([tf.ones_like(tensor_a[..., :1]), tensor_a[..., 1:]], axis=-1)

Аналогично для последнего измерения :

result = tf.concat([tensor_a[..., :-1], tf.ones_like(tensor_a[..., -1:])], axis=-1)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...