как сделать назначение слайса, в то время как сам слайс является тензорным в тензорном потоке - PullRequest
1 голос
/ 19 июня 2019

Я хочу сделать назначение среза в тензорном потоке.Я узнал, что могу использовать:

my_var = my_var[4:8].assign(tf.zeros(4))

на основе этой ссылки .

, как вы видите в my_var[4:8] у нас есть конкретные индексы 4, 8 здесьдля нарезки, а затем присваивания.

Мой случай отличается, я хочу сделать нарезку на основе тензора, а затем выполнить присваивание.

out = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))

 rows_tf = tf.constant (
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])

columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])

changed_tensor = [[8.3356,    0.,        8.457685 ],
                  [0.,        6.103182,  8.602337 ],
                  [8.8974,    7.330564,  0.       ],
                  [0.,        3.8914037, 5.826657 ],
                  [8.8974,    0.,        8.283971 ],
                  [6.103182,  3.0614321, 5.826657 ],
                  [7.330564,  0.,        8.283971 ],
                  [6.103182,  3.8914037, 0.       ]]

Кроме того, это тензор sparse_indices,который представляет собой конкат rows_tf и columns_tf, делающий целые индексы, которые необходимо обновить (в случае, если это может помочь:)

sparse_indices = tf.constant(
[[1 1]
 [2 1]
 [5 1]
 [1 2]
 [2 2]
 [5 2]
 [1 3]
 [2 3]
 [5 3]
 [1 2]
 [4 2]
 [6 2]
 [1 3]
 [4 3]
 [6 3]
 [2 2]
 [3 2]
 [6 2]
 [2 3]
 [3 3]
 [6 3]
 [2 2]
 [4 2]
 [4 2]])

Что я хочу сделать, это выполнить это простое задание:

out[rows_tf, columns_tf] = changed_tensor

для этого я делаю это:

out[rows_tf:column_tf].assign(changed_tensor)

Однако я получил эту ошибку:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,8,3], [1,8,1], and [1] instead. [Op:StridedSlice] name: strided_slice/

это ожидаемый результат:

[[0.        0.        0.        0.       ]
 [0.        8.3356    0.        8.8974   ]
 [0.        0.        6.103182  7.330564 ]
 [0.        0.        3.0614321 0.       ]
 [0.        0.        3.8914037 0.       ]
 [0.        8.457685  8.602337  0.       ]
 [0.        0.        5.826657  8.283971 ]
 [0.        0.        0.        0.       ]]

Есть идеи, как мне закончить эту миссию?

Заранее спасибо:)

1 Ответ

2 голосов
/ 19 июня 2019

Этот пример (расширенный из документации tf tf.scatter_nd_update здесь ) должен помочь.

Вы хотите сначала объединить ваши row_indices и column_indices в список двумерных индексов, который является indices аргументом для tf.scatter_nd_update. Затем вы подали список ожидаемых значений, который составляет updates.

ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
indices = tf.constant([[0, 2], [2, 2]])
updates = tf.constant([1.0, 2.0])

update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  print sess.run(update)
Result:

[[ 0.  0.  1.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  2.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]]

Специально для ваших данных,

ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
changed_tensor = [[8.3356,    0.,        8.457685 ],
                  [0.,        6.103182,  8.602337 ],
                  [8.8974,    7.330564,  0.       ],
                  [0.,        3.8914037, 5.826657 ],
                  [8.8974,    0.,        8.283971 ],
                  [6.103182,  3.0614321, 5.826657 ],
                  [7.330564,  0.,        8.283971 ],
                  [6.103182,  3.8914037, 0.       ]]
updates = tf.reshape(changed_tensor, shape=[-1])
sparse_indices = tf.constant(
[[1, 1],
 [2, 1],
 [5, 1],
 [1, 2],
 [2, 2],
 [5, 2],
 [1, 3],
 [2, 3],
 [5, 3],
 [1, 2],
 [4, 2],
 [6, 2],
 [1, 3],
 [4, 3],
 [6, 3],
 [2, 2],
 [3, 2],
 [6, 2],
 [2, 3],
 [3, 3],
 [6, 3],
 [2, 2],
 [4, 2],
 [4, 2]])

update = tf.scatter_nd_update(ref, sparse_indices, updates)
with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  print sess.run(update)

Result:

[[ 0.          0.          0.          0.        ]
 [ 0.          8.3355999   0.          8.8973999 ]
 [ 0.          0.          6.10318184  7.33056402]
 [ 0.          0.          3.06143212  0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          8.45768547  8.60233688  0.        ]
 [ 0.          0.          5.82665682  8.28397083]
 [ 0.          0.          0.          0.        ]]
...