tf_gather_nd и tensor_scatter_nd_update в пакетах - PullRequest
1 голос
/ 17 июня 2020

Я борюсь с tf_gather_nd и tensor_scatter_nd_update.

Сначала: Я пытаюсь проиндексировать пакет тензоров. Тензор params имеет размерность (4, 64, 2), а мой исходный тензор индексов имеет размерность (4, 64). Мне удалось решить эту проблему, вставив индексы от 0 до 63 в мой исходный тензор индексов.

idx = np.array([i for i in range(64)])
indices_adj = tf.map_fn(lambda x: tf.stack([idx, x], axis=1), indices)
tf.gather_nd(params, indices_adj, batch_dims=1)

Однако мне интересно, есть ли лучшее решение, чем это?

Второй: Мне также нужно обновить пакет тензоров размерности (4, 64, 2) значениями в тензоре (4, 64) с индексами, предоставленными в тензоре (4, 64). Однако, поскольку tensor_scatter_nd_update не предоставляет никаких пакетных функций, таких как tf_gather_nd, я понятия не имею, как это эффективно реализовать.

До работы с пакетами мой код выглядел просто так: Я ценю любую помощь!

1 Ответ

0 голосов
/ 17 июня 2020

Первую операцию, которую вы хотите выполнить, можно просто выполнить следующим образом:

import tensorflow as tf

tf.random.set_seed(0)
params = tf.random.uniform([4, 64, 2])
idx = tf.random.uniform([4, 64], 0, 2, dtype=tf.int32)
out = tf.gather_nd(params, tf.expand_dims(idx, -1), batch_dims=2)
print(out.shape)
# (4, 64)

Для второй вам нужно построить полный многомерный индекс:

import tensorflow as tf

tf.random.set_seed(0)
params = tf.random.uniform([4, 64, 2])
idx = tf.random.uniform([4, 64], 0, 2, dtype=tf.int32)
update = tf.random.uniform([4, 64])
s = tf.shape(idx, out_type=idx.dtype)
ii, jj = tf.meshgrid(tf.range(s[0]), tf.range(s[1]), indexing='ij')
idx_comp = tf.stack([ii, jj, idx], axis=-1)
out = tf.tensor_scatter_nd_update(params, idx_comp, update)

Хотя в В конкретном случае, когда ваше последнее измерение имеет два элемента, вы также можете использовать эту эквивалентную операцию:

update_t = tf.tile(tf.expand_dims(update, axis=-1), [1, 1, 2])
idx_t = tf.stack([idx, 1 - idx], axis=-1)
out = tf.where(tf.dtypes.cast(idx_t, tf.bool), params, update_t)
...