Предварительное вычисление Pytorch части функции потерь - PullRequest
0 голосов
/ 01 августа 2020

Я пытаюсь использовать pytorch для вычисления сложной функции, но все же требую градиент вывода как функцию входных данных. Например:

a=torch.tensor([1,2,3], dtype=torch.float32, requires_grad=True)
b=torch.tensor([3,2,1], dtype=torch.float32, requires_grad=True)
v=torch.tensor([0.2], dtype=torch.float32, requires_grad=True)

# here precalc represents some (fairly expensive) sequence of operations
precalc = a.dot(b)+a*b+a*a+b*b

def calc(precalc, v):
    z=torch.randn(3,1000)
    batch=v*precalc.matmul(z)
    return torch.relu(batch).mean()

Итак, предварительный тензор c является фиксированной функцией от a и b (но требует больших затрат на вычисление).

Когда я вызываю cal c для первого время как x=calc(precalc, v), а затем x.backward(), я получаю правильные градиенты для a, b и v. Однако, если я вызываю cal c второй раз, т.е. x2=calc(precalc, v), а затем x2.backward(), я получаю pytorch ошибка, я возвращаюсь назад по графику во второй раз, и я должен сохранить график.

В идеале я бы хотел освободить график, так как предварительный c требует больших затрат на вычисление, но мало в памяти и c сам по себе вычислить проще, но для этого потребуется много памяти.

Есть ли способ добиться этого?

Спасибо.

...