Факел: обратные градиенты без обновления переменных - PullRequest
1 голос
/ 20 марта 2020

Есть ли в pytorch эффективный способ обратного распространения градиентов, но не обновления их соответствующих переменных? Кажется, делать копию весов каждый раз во время обновления слишком дорого. Но no_grad & set_grad_enabled не допускает обратного распространения.

Пример. Следующее, кажется, занимает слишком много времени, так как необходимо делать копию модели при каждом обновлении весов:

    def __init__():
        …
        self.model = MyModel()
        self.func1 = FuncModel1()
        self.func2 = FuncModel2()
        …
    def trainstep(input):
        f1 = self.func1(input)
        f2 = self.func2(input)
        …
        # want to update weights in model & f1 with respect to loss1
        loss1 = my_loss(model(f1), y1) 

        # don’t want to update weights in self.model with respect to loss2
        # but want to update weights in f2 for loss2
        copy_model = MyModel()
        copy_model.load_state_dict(self.model.state_dict())
        loss2 = my_loss(copy_model(f2), y2) 

        total_loss = loss1 + loss2
        …

        total_loss.backward()
        optimizer.step()

1 Ответ

0 голосов
/ 22 марта 2020

когда loss.backward() pytorch распространяет градиенты по всему графу вычислений.
Однако сама функция backward() не обновляет весовые коэффициенты, она только вычисляет градиенты.
Обновление выполняется через оптимизатор в optimizer.step(). Если вы хотите исключить веса f1 и f2 из обновлений, вы можете просто либо
- init optimizer без параметров f1 и f2.
- установить обучение ставка для f1 и f2 до нуля.

...