Как обновить один столбец 2d tf.Variable? - PullRequest
0 голосов
/ 03 мая 2020

Допустим, у меня есть MxN-образный tf.Variable, в котором хранится некоторое состояние моего пользовательского слоя:

import tensorflow as tf
m, n = 3, 4  # just for example
v = tf.Variable(tf.zeros([m, n]), trainable=False)

# v = <tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
#     array([[0., 0., 0., 0.],
#            [0., 0., 0., 0.],
#            [0., 0., 0., 0.]], dtype=float32)>

Я знаю, что могу обновить значения этой переменной с помощью v.assign(...) , но как я могу обновить только подраздел этой переменной? Например, я хотел бы вставить данный вектор в данный столбец.

x = tf.ones([m,1])
c = tf.Variable(2)
# update v by inserting x at column c

... так, чтобы следующие значения были новыми значениями v:

# v = <tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
#     array([[0., 0., 1., 0.],
#            [0., 0., 1., 0.],
#            [0., 0., 1., 0.]], dtype=float32)>

1 Ответ

1 голос
/ 03 мая 2020

с ТФ 2,2

m, n = 3, 4  # just for example
v = tf.Variable(tf.zeros([m, n]), trainable=False)
x = tf.ones(m)
c = 2

change_v = v[:,c].assign(x)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...