Tensorflow перезаписывает область в пользовательском слое - PullRequest
0 голосов
/ 02 февраля 2019

Я пытаюсь реализовать шумный линейный слой в тензорном потоке, унаследованный от tf.keras.layers.Layer.Все работает отлично, за исключением повторного использования переменных.Похоже, это связано с некоторой проблемой с областью видимости: всякий раз, когда я использую функцию add_weight из суперкласса, а вес с таким именем уже существует, он, кажется, игнорирует заданный флаг повторного использования в области и вместо этого создает новую переменную.Интересно, что он не добавляет 1 к имени переменной в конце, как обычно в аналогичных случаях, а скорее добавляет 1 к имени области.

import tensorflow as tf

class NoisyDense(tf.keras.layers.Layer):
    def __init__(self,output_dim):
        self.output_dim=output_dim

        super(NoisyDense, self).__init__()

    def build(self, input_shape):
        self.input_dim = input_shape.as_list()[1]
        self.noisy_kernel = self.add_weight(name='noisy_kernel',shape=  (self.input_dim,self.output_dim))

def noisydense(inputs, units):

    layer = NoisyDense(units)

    return layer.apply(inputs)

inputs = tf.placeholder(tf.float32, shape=(1, 10),name="inputs")



scope="scope"
with tf.variable_scope(scope):
    inputs3 = noisydense(inputs,
           1)
    my_variable = tf.get_variable("my_variable", [1, 2, 3],trainable=True)


with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    inputs2 = noisydense(inputs,
           1)
    my_variable = tf.get_variable("my_variable", [1, 2, 3],trainable=True)

tvars = tf.trainable_variables()



init=tf.global_variables_initializer()    
with tf.Session() as sess:
    sess.run(init)
    tvars_vals = sess.run(tvars)

for var, val in zip(tvars, tvars_vals):
    print(var.name, val)

Это приводит к переменным

  scope/noisy_dense/noisy_kernel:0
   scope_1/noisy_dense/noisy_kernel:0
   scope/my_variable:0

печатается.Я бы хотел использовать ядро ​​с шумом вместо создания второго, как это делается для my_variable.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...