Как можно заморозить параметры горелки, если используется оптимизатор с импульсом? - PullRequest
2 голосов
/ 28 апреля 2020

Я тренирую модель факела, где хочу заморозить (а затем разморозить) определенные параметры. У меня сложилось впечатление, что простая установка param.requires_grad = False сделает это sh. Это не относится к оптимизаторам с динамикой. Я знаю, что могу либо создать новый оптимизатор, либо изменить параметры существующего, но ни один из них не позволит мне разморозить параметры (легко) и без сохранения дополнительной ссылки на все параметры, которые оптимизатор первоначально изменял.

Я думаю, что желаемого результата можно достичь, установив моментум_буффер в состоянии оптимизатора на ноль, но я не уверен, как это сделать, к нему нелегко получить доступ.

Код ниже можно использовать для воспроизведения эффектов, закомментировав оба известных «решения».

import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view((x.size()[0], -1))

def main():
    data = torchvision.datasets.MNIST("./data",download=True,
                                       transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize((0.1307,), (0.3081,))
                                        ]))
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=1000,
                                              shuffle=True)

    net=nn.Sequential(*[Flatten(),
                    nn.Linear(28*28,100),
                    nn.ReLU(),
                    nn.Linear(100,10)])
    opt=torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

    for e in range(2):
        old_params = [p.clone() for p in net.parameters()]
        if e == 1:
            for j,p in enumerate(net.parameters()):
                if j<2:
                    p.requires_grad = False
            # opt.param_groups[0]['params'] = opt.param_groups[0]['params'][2:]

        # opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

        for data, label in tqdm(data_loader):
            loss=torch.nn.functional.cross_entropy(net(data),label)
            opt.zero_grad()
            loss.backward()
            opt.step()
        print(loss)

        new_params=[p.clone() for p in net.parameters()]
        change = [(~(p1 == p2).all()).item() for p1, p2 in zip(old_params, new_params)]
        print("Epoch: %d \t params changed: %s" % (e, change))
        print([p.requires_grad for p in net.parameters()])


if __name__ == '__main__':
    main()
...