Мне пришлось копаться во внутреннем устройстве GradientTape
, но мне удалось это понять. Поделиться здесь для всех, у кого может быть такая же проблема. Спойлер: это немного взломано!
Прежде всего, что на самом деле происходит при вызове
with tf.GradientTape() as tape:
loss_value = self.loss()
tape.gradient(loss_value, vars)
Чтобы это выяснить, нам нужно проверить функции __enter__()
и __exit__()
которые вызываются в начале и в конце блока with
соответственно.
in tensorflow_core/python/eager/backprop.py
def __enter__(self):
"""Enters a context inside which operations are recorded on this tape."""
self._push_tape()
return self
def __exit__(self, typ, value, traceback):
"""Exits the recording context, no further operations are traced."""
if self._recording:
self._pop_tape()
Мы можем сами использовать эти частные функции для управления записью без необходимости a with
block.
# Initialize outer and inner tapes
self.gt_outer = tf.GradientTape(persistent=True)
self.gt_inner = tf.GradientTape(persistent=True)
# Begin Recording
self.gt_outer._push_tape()
self.gt_inner._push_tape()
# evaluate loss which uses self.variables
loss_val = self.loss()
# stop recording on inner tape
self.gt_inner._pop_tape()
# Evaluate the gradient on the inner tape
self.gt_grad = self.gt_inner.gradient(loss_val, self.variables)
# Stop recording on the outer tape
self.gt_outer._pop_tape()
Теперь всякий раз, когда нам нужно оценить произведение вектора Гесса, мы можем повторно использовать внешнюю градиентную ленту.
def hessian_v_prod(self, v):
self.gt_outer._push_tape()
v_hat = tf.reduce(tf.multiply(v, self.gt_grad))
self.gt_outer._pop_tape()
return self.gt_outer.gradient(v_hat, self.variables)
Обратите внимание, что мы сохраняем ленты, поэтому каждый раз, когда оценивается гессианский векторный продукт, он использует больше памяти. Невозможно сохранить часть ленточной памяти, поэтому в определенные моменты возникает необходимость перезагрузить ленты.
# reset tapes
self.gt_outer._tape = None
self.gt_inner._tape = None
Чтобы использовать их снова, после этого нам нужно переоценить внутренний l oop. Он не идеален, но выполняет свою работу и дает значительное ускорение (почти в 2 раза) за счет большего использования памяти.