Понимание tf.scatter_nd_update: Как обновить значения столбцов? - PullRequest
0 голосов
/ 26 февраля 2019

Я пытаюсь перевести операцию NumPy по частичному обновлению в TensorFlow.я хочу воспроизвести следующий минимальный пример:

input = np.arange(3 * 5).reshape((3, 5))

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14]])

input[:, [0, 2]] = -1

array([[-1,  1, -1,  3,  4],
       [-1,  6, -1,  8,  9],
       [-1, 11, -1, 13, 14]])

Итак, я хочу установить постоянное значение для всех элементов определенных столбцов в массиве.

Теперь у меня есть тензоры вместоМассивы NumPy, индексы столбцов также рассчитываются динамически и сохраняются в Tensors.Я нашел, как обновить все значения в данных строках , используя tf.scatter_nd_update:

input = tf.Variable(tf.reshape(tf.range(3 * 5, dtype=tf.int32), [3, 5]))                                                                                                                                                                                          
indices = tf.constant([[0], [2]])                                                                                                                                                                                                                                 
updates = tf.constant([[-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1]])                                                                                                                                                                                               

scatter = tf.scatter_nd_update(input, indices, updates)                                                                                                                                                                                                           

with tf.Session() as sess:                                                                                                                                                                                                                                        
    sess.run(tf.global_variables_initializer())                                                                                                                                                                                                                   
    print(sess.run(scatter)) 

Вывод:

[[-1 -1 -1 -1 -1]
 [ 5  6  7  8  9]
 [-1 -1 -1 -1 -1]]

Но как я могу сделать это дляопределенные столбцы?

1 Ответ

0 голосов
/ 26 февраля 2019

Вы можете сделать это следующим образом:

import tensorflow as tf

def update_columns(variable, columns, value):
    columns = tf.convert_to_tensor(columns)
    rows = tf.range(tf.shape(variable)[0], dtype=columns.dtype)
    ii, jj = tf.meshgrid(rows, columns, indexing='ij')
    value = tf.broadcast_to(value, tf.shape(ii))
    return tf.scatter_nd_update(variable, tf.stack([ii, jj], axis=-1), value)

inp = tf.Variable(tf.reshape(tf.range(3 * 5, dtype=tf.int32), [3, 5]))
updated = update_columns(inp, [0, 2], -1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(updated))

Вывод:

[[-1  1 -1  3  4]
 [-1  6 -1  8  9]
 [-1 11 -1 13 14]]

Обратите внимание, что вы должны использовать tf.scatter_nd_update только если вы действительно хотитеработать с переменной (и присвоить ей новое значение).Если вы хотите получить тензор, равный другому, но с некоторыми обновленными значениями, вы должны использовать обычные тензорные операции вместо преобразования его в переменную.Например, для этого случая вы можете сделать:

import tensorflow as tf

def update_columns_tensor(tensor, columns, value):
    columns = tf.convert_to_tensor(columns)
    shape = tf.shape(tensor)
    num_rows, num_columns = shape[0], shape[1]
    mask = tf.equal(tf.range(num_columns, dtype=columns.dtype), tf.expand_dims(columns, 1))
    mask = tf.tile(tf.expand_dims(tf.reduce_any(mask, axis=0), 0), (num_rows, 1))
    value = tf.broadcast_to(value, shape)
    return tf.where(mask, value, tensor)

inp = tf.reshape(tf.range(3 * 5, dtype=tf.int32), [3, 5])
updated = update_columns_tensor(inp, [0, 2], -1)
with tf.Session() as sess:
    print(sess.run(updated))
    # Same output
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...