Есть ли в pytorch эффективный способ обратного распространения градиентов, но не обновления их соответствующих переменных? Кажется, делать копию весов каждый раз во время обновления слишком дорого. Но no_grad
& set_grad_enabled
не допускает обратного распространения.
Пример. Следующее, кажется, занимает слишком много времени, так как необходимо делать копию модели при каждом обновлении весов:
def __init__():
…
self.model = MyModel()
self.func1 = FuncModel1()
self.func2 = FuncModel2()
…
def trainstep(input):
f1 = self.func1(input)
f2 = self.func2(input)
…
# want to update weights in model & f1 with respect to loss1
loss1 = my_loss(model(f1), y1)
# don’t want to update weights in self.model with respect to loss2
# but want to update weights in f2 for loss2
copy_model = MyModel()
copy_model.load_state_dict(self.model.state_dict())
loss2 = my_loss(copy_model(f2), y2)
total_loss = loss1 + loss2
…
total_loss.backward()
optimizer.step()