for batch_idx, (inputs, targets) in enumerate(testloader):
optimizer.zero_grad()
inputs = inputs.float().cuda()
inputs, targets = inputs.to(device), targets.to(device)
outputs = im_net(inputs)
ce_loss = criterion(outputs, targets)
loss = criterion(outputs, targets)
loss.backward()
im_net.features[25].weight.grad[temp_5.values] = 0
print(im_net.features[25].weight.grad[temp_5.values])
optimizer.step()
Выше приведен мой код для попытки исправить подмножество (temp_5.values
) фильтра 25-го слоя свертки. Я пытаюсь установить градиент фильтра поднабора как 0, чтобы не обновлять. Когда я пытаюсь напечатать градиент, все они равны нулю, но вес изменился.
Итак, два вопроса:
- Интересно, почему нулевой градиент приводит к изменению веса;
- Есть ли другое решение для исправления фильтра свертки подмножества?