Я пытаюсь использовать pytorch для вычисления сложной функции, но все же требую градиент вывода как функцию входных данных. Например:
a=torch.tensor([1,2,3], dtype=torch.float32, requires_grad=True)
b=torch.tensor([3,2,1], dtype=torch.float32, requires_grad=True)
v=torch.tensor([0.2], dtype=torch.float32, requires_grad=True)
# here precalc represents some (fairly expensive) sequence of operations
precalc = a.dot(b)+a*b+a*a+b*b
def calc(precalc, v):
z=torch.randn(3,1000)
batch=v*precalc.matmul(z)
return torch.relu(batch).mean()
Итак, предварительный тензор c является фиксированной функцией от a и b (но требует больших затрат на вычисление).
Когда я вызываю cal c для первого время как x=calc(precalc, v)
, а затем x.backward()
, я получаю правильные градиенты для a, b и v. Однако, если я вызываю cal c второй раз, т.е. x2=calc(precalc, v)
, а затем x2.backward()
, я получаю pytorch ошибка, я возвращаюсь назад по графику во второй раз, и я должен сохранить график.
В идеале я бы хотел освободить график, так как предварительный c требует больших затрат на вычисление, но мало в памяти и c сам по себе вычислить проще, но для этого потребуется много памяти.
Есть ли способ добиться этого?
Спасибо.