Я тренирую модель факела, где хочу заморозить (а затем разморозить) определенные параметры. У меня сложилось впечатление, что простая установка param.requires_grad = False
сделает это sh. Это не относится к оптимизаторам с динамикой. Я знаю, что могу либо создать новый оптимизатор, либо изменить параметры существующего, но ни один из них не позволит мне разморозить параметры (легко) и без сохранения дополнительной ссылки на все параметры, которые оптимизатор первоначально изменял.
Я думаю, что желаемого результата можно достичь, установив моментум_буффер в состоянии оптимизатора на ноль, но я не уверен, как это сделать, к нему нелегко получить доступ.
Код ниже можно использовать для воспроизведения эффектов, закомментировав оба известных «решения».
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
return x.view((x.size()[0], -1))
def main():
data = torchvision.datasets.MNIST("./data",download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
]))
data_loader = torch.utils.data.DataLoader(data,
batch_size=1000,
shuffle=True)
net=nn.Sequential(*[Flatten(),
nn.Linear(28*28,100),
nn.ReLU(),
nn.Linear(100,10)])
opt=torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
for e in range(2):
old_params = [p.clone() for p in net.parameters()]
if e == 1:
for j,p in enumerate(net.parameters()):
if j<2:
p.requires_grad = False
# opt.param_groups[0]['params'] = opt.param_groups[0]['params'][2:]
# opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
for data, label in tqdm(data_loader):
loss=torch.nn.functional.cross_entropy(net(data),label)
opt.zero_grad()
loss.backward()
opt.step()
print(loss)
new_params=[p.clone() for p in net.parameters()]
change = [(~(p1 == p2).all()).item() for p1, p2 in zip(old_params, new_params)]
print("Epoch: %d \t params changed: %s" % (e, change))
print([p.requires_grad for p in net.parameters()])
if __name__ == '__main__':
main()