Как выполнить обновление тензорного среза, как в pytorch? - PullRequest
0 голосов
/ 04 октября 2018

В Pytorch вы можете легко обновить тензор, как показано ниже:

 for i in range(x_len):
     tensor_abc[:, i, i] = 0

Как мы можем обновить тензор, подобный этому, в кодировании тензорного потока?Я пробовал tf.assign, не могу обновить срез .. пробовал tf.scatter_update, тоже не работает ...

Ответы [ 2 ]

0 голосов
/ 04 октября 2018

tf.Variable s являются единственными тензорами, которые могут быть обновлены (https://www.tensorflow.org/guide/variables). С переменными, вы будете использовать код, такой как gather и scatter_update для нарезки.

Обратите внимание, что другие тензорыне поддаются присвоению. Если это то, что вы пытаетесь сделать, я бы удивился, почему это необходимо. Тем не менее, все еще можно создавать новые тензоры со значениями, которые вы хотите (вместо назначения на месте), с кодом, который являетсянемного запутанный. Например, следующее не работает:

index = ... tensor = tf.constant([0,1,2,3,4]) 
tensor[i] = 0  
## Doesn't work (TypeError: `Tensor` object does not support item assignment)

Но это может сделать эквивалент:

tensor = tf.constant([0,1,2,3,4]) 
tensor = tf.concat([tensor[:i], tf.zeros_like(tensor[i:i+1]), tensor[i+1:]], 0)  
## This works, creates a new tensor

ИЛИ:

tensor = tf.constant([0,1,2,3,4]) 
tensor = tf.concat([tensor[:i], tf.fill([1], 0), tensor[i+1:]], 0)  
## This works, creates a new tensor
0 голосов
/ 04 октября 2018

Этот ответ относится только к переменным.

import tensorflow as tf

sess = tf.InteractiveSession()
v = tf.zeros((5,5,5))
var = tf.Variable(initial_value=v)


init = tf.variables_initializer([var])
sess.run(init)


var = var[ 1 : 2 ,
           1 : 2 ,
           1 : 2 ].assign(tf.ones((1,1,1)))

print(sess.run(var))

Это производит

[[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]]

И это

var = var[ 1 : 2 ,
           0 : 1 ,
           0 : 1 ].assign(tf.ones((1,1,1)))

производит

  [[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

  ....
  ....]]

Другой пример:

var = var[ 1 : 2 ,
             : 2 ,
             : 2 ].assign(tf.ones((1,2,2)))

[[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[1. 1. 0. 0. 0.]
  [1. 1. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

      ....
      ....]]

Вам следует изучить tf.scatter_nd для тензоров.

...