Я думаю, что если вы сгладите x
и var
, tf.scatter_update
сможет это сделать.И тогда вы можете изменить его.
x = tf.random_uniform((5,3), minval=0, maxval=1, dtype=tf.float32)
var = tf.Variable(tf.zeros((5 , 3), tf.int32))
flattenx = tf.Variable(tf.reshape(x, [-1]))
flattenvar = tf.Variable(tf.reshape(var, [-1]))
minusone = tf.squeeze(tf.where(tf.less(flattenx, .25)))
plusone = tf.squeeze(tf.where(tf.greater(flattenx, .75)))
with tf.Session() as sess :
sess.run(tf.global_variables_initializer())
#create the required number of -1's using tile and update
update = tf.scatter_update(flattenvar,
minusone,
tf.tile(tf.constant([-1],
tf.int32),
tf.shape(minusone)))
#create the required number of 1's using tile and update
update = tf.scatter_update(update,
plusone,
tf.tile(tf.constant([1],
tf.int32),
tf.shape(plusone)))
print(sess.run(tf.transpose(tf.reshape(update,(3,5)))))
print(sess.run(tf.transpose(tf.reshape(flattenx,(3,5)))))
Я так транспонирую, потому что, если я непосредственно изменю его в (5,3) позиции не совпадают.
Вывод это.
[[ 0 1 -1]
[ 0 0 -1]
[ 0 1 0]
[ 1 0 1]
[-1 0 0]]
[[0.53033566 0.8070284 0.0320853 ]
[0.7295474 0.684116 0.18633509]
[0.4942881 0.89346325 0.5873647 ]
[0.96912587 0.5400375 0.8372116 ]
[0.14598823 0.62156534 0.54353106]]