Маскирование переменной Tensorflow в Tensorflow 2 - PullRequest
0 голосов
/ 26 мая 2020

Я хочу замаскировать переменную Tensorflow во время ее определения. В Tensorflow 1 это можно сделать путем умножения на маску W. Однако, если я сделаю это в Tensorflow 2, переменная станет необучаемой. Есть ли какие-либо предложения, как я могу замаскировать переменную Tensorflow?

Если я использую следующий код, он делает var необучаемым.

self.var=tf.Variable(initial_value=tf.random.uniform(self.W_mask.shape,0,1),
                            trainable=True,dtype='float32',name='var')
self.var=tf.multiply(self.var,self.W_mask)
...