У меня есть вычислительный граф PyTorch, который состоит из подграфа, выполняющего некоторое вычисление, и результат этого вычисления (назовем его x
) затем разветвляется на два других подграфа.Каждый из этих двух подграфов дает некоторые скалярные результаты (назовем их y1
и y2
).Я хочу сделать обратный проход для каждого из этих двух результатов (то есть я хочу накапливать градиенты двух подграфов. Я не хочу выполнять фактический шаг оптимизации).
Теперь,так как память здесь является проблемой, я хотел бы выполнить операции в следующем порядке: во-первых, рассчитать x
.Затем вычислите y1
и выполните y1.backward()
, в то время как (и это ключевой момент) , сохраняя график, ведущий к x
, но освобождая график от x
до y1
.Затем вычислите y2
и выполните y2.backward()
.
Другими словами, чтобы сэкономить память, не жертвуя слишком большой скоростью, я хочу сохранить x
без необходимости пересчитывать ее, но я хочу отброситьвсе вычисления, начинающиеся с x
до y1
после того, как они мне больше не нужны.
Проблема в том, что аргумент retain_graph
функции backward()
сохранит весь граф, ведущий к y1
, тогда как мне нужно сохранить только часть графика, ведущую к x
.
Вот пример того, что я в идеале хотел бы:
import torch
w = torch.tensor(1.0)
w.requires_grad_(True)
# sub-graph for calculating `x`
x = w+10
# sub-graph for calculating `y1`
x1 = x*x
y1 = x1*x1
y1.backward(retain_graph=x) # this would not work, since retain_graph is a boolean and can either retain the entire graph or free it.
# sub-graph for calculating `y2`
x2 = torch.sqrt(x)
y2 = x2/2
y2.backward()
Как это можно сделать