Мне нужно вычислить гессианы некоторой модели TensorFlow относительно ее входных данных.
Код выглядит примерно так:
class model()
...
def hessians(self, X):
X = tf.convert_to_tensor(X, dtype=tf.float32)
with tf.GradientTape() as tape1:
tape1.watch(X)
with tf.GradientTape() as tape2:
tape2.watch(X)
y = self(X)
df_dx = tape2.gradient(y,X)
d2f_dx2 = tape1.batch_jacobian(df_dx, X)
return d2f_dx2.numpy()
При вызове параметр X
представляет собой массив numpy, где одна строка представляет один вход. Эта реализация является более или менее копией примера из https://www.tensorflow.org/api_docs/python/tf/GradientTape (для производных более высокого порядка) и дает ожидаемый результат, но есть проблемы с потреблением памяти.
Функция резервирует слишком много памяти, и эта память не освобождается. Я выполнил следующий тест с memory_profiler:
@profile
def workflow(model, N):
h = create_hessians(model, N)
time.sleep(5) #do some stuff
print("Finished")
def create_hessians(model, N):
d = 200
R = np.sqrt(d)
X = np.random.normal(size=(N,d))
h = model.hessians(X=X)
print("Size of hessians = {}B".format(sys.getsizeof(h)))
return h
Вывод:
Line # Mem usage Increment Line Contents
================================================
49 338.4 MiB 338.4 MiB @profile
50 def hessians(self, X):
51 342.2 MiB 3.7 MiB X = tf.convert_to_tensor(X, dtype=tf.float32)
52 342.2 MiB 0.0 MiB with tf.GradientTape() as tape1:
53 342.4 MiB 0.2 MiB tape1.watch(X)
54 342.4 MiB 0.0 MiB with tf.GradientTape() as tape2:
55 342.4 MiB 0.0 MiB tape2.watch(X)
56 352.6 MiB 10.2 MiB y = self(X)
57 366.7 MiB 14.2 MiB df_dx = tape2.gradient(y,X)
58 4493.2 MiB 4126.5 MiB d2f_dx2 = tape1.batch_jacobian(df_dx, X)
59 5256.4 MiB 763.2 MiB return d2f_dx2.numpy()
Size of hessians = 800000128B
Finished
Filename: /tmp/test.py
Line # Mem usage Increment Line Contents
================================================
18 330.9 MiB 330.9 MiB @profile
19 def workflow(model, N):
20 5245.0 MiB 4914.1 MiB h = create_hessians(model, N)
21 5245.0 MiB 0.0 MiB time.sleep(5) #do some stuff
22 5245.0 MiB 0.0 MiB print("Finished")
Если я не ошибаюсь, использование памяти должно упасть до ~ 330 + 800 (размер h) после функция возвращает. Любая помощь приветствуется.