Вероятность тензорного потока - тренировка бижектора - PullRequest
0 голосов
/ 06 августа 2020

Я пытался следовать примеру из этого учебного пособия , но у меня возникли проблемы с обучением какой-либо из переменных.

Я написал небольшой пример, но я не был может выполнить эту работу либо:

# Train a shift bijector
shift = tf.Variable(initial_value=tf.convert_to_tensor([1.0], dtype=tf.float32), trainable=True, name='shift_var')
bijector = tfp.bijectors.Shift(shift=shift)

# Input
x = tf.convert_to_tensor(np.array([0]), dtype=tf.float32)
target = tf.convert_to_tensor(np.array([2]), dtype=tf.float32)

optimizer = tf.optimizers.Adam(learning_rate=0.5)
nsteps = 1

print(bijector(x).numpy(), bijector.shift)
for _ in range(nsteps):

    with tf.GradientTape() as tape:
        out = bijector(x)
        loss = tf.math.square(tf.math.abs(out - target))
        #print(out, loss)
    
        gradients = tape.gradient(loss, bijector.trainable_variables)
    
    optimizer.apply_gradients(zip(gradients, bijector.trainable_variables))
    
print(bijector(x).numpy(), bijector.shift)

Для nsteps = 1 два оператора печати приводят к следующему выводу:

[1.] <tf.Variable 'shift_var:0' shape=(1,) dtype=float32, numpy=array([1.], dtype=float32)>
[1.] <tf.Variable 'shift_var:0' shape=(1,) dtype=float32, numpy=array([1.4999993], dtype=float32)>

Кажется, что bijector все еще использует исходный shift, хотя напечатанное значение bijector.shift было обновлено.

Я не могу увеличить nsteps, так как градиент None после первой итерации, и я получил эту ошибку:

ValueError: No gradients provided for any variable: ['shift_var:0'].

Я использую

tensorflow version 2.3.0
tensorflow-probability version 0.11.0

Я также пробовал его на ноутбуке colab, поэтому сомневаюсь, что это проблема версии.

Ответы [ 2 ]

1 голос
/ 24 августа 2020

Вы нашли ошибку. Функция прямого бижектора слабо кэширует отображение результат-> вход, чтобы сделать обратные обратные потоки и лог-детерминанты быстрыми. Но как-то это мешает градиенту. Обходной путь - добавление del out, например https://colab.research.google.com/gist/brianwa84/04249c2e9eb089c2d748d05ee2c32762/bijector-cache-bug.ipynb

0 голосов
/ 07 августа 2020

Все еще не уверен, что я точно понимаю, что здесь происходит, но, по крайней мере, я могу заставить свой пример работать сейчас.

По какой-то причине поведение будет другим, если я заключу его в класс, унаследованный от tf.keras.Model:

class BijectorModel(tf.keras.Model):

    def __init__(self):
        super().__init__()

        self.shift = tf.Variable(initial_value=tf.convert_to_tensor([1.5], dtype=tf.float32), trainable=True, name='shift_var')
        self.bijector = tfp.bijectors.Shift(shift=self.shift)

    def call(self, input):
        return self.bijector(input)

Я сделал функцию для обучающей итерации, хотя в этом нет необходимости:

def training_iteration(model, input, target):

    optimizer = tf.optimizers.SGD(learning_rate=0.1)

    with tf.GradientTape() as tape:

        loss = tf.math.square(tf.math.abs(model(input) - target))

        gradients = tape.gradient(loss, model.trainable_variables)

    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Выполнение как это

x = tf.convert_to_tensor(np.array([0]), dtype=tf.float32)
target = tf.convert_to_tensor(np.array([2]), dtype=tf.float32)
model = BijectorModel()

nsteps = 10
for _ in range(nsteps):
    training_iteration(model, x, target)
    print('Iteration {}: Output {}'.format(_, model(x)))

дает ожидаемый / желаемый результат:

Iteration 0: Output [1.6]
Iteration 1: Output [1.6800001]
Iteration 2: Output [1.7440001]
Iteration 3: Output [1.7952001]
Iteration 4: Output [1.8361601]
Iteration 5: Output [1.8689281]
Iteration 6: Output [1.8951424]
Iteration 7: Output [1.916114]
Iteration 8: Output [1.9328911]
Iteration 9: Output [1.9463129]

Я пришел к выводу, что обучаемые переменные обрабатываются иначе, когда они являются частью модели, по сравнению с доступом через объект-биектор.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...