Как я могу исправить или заморозить подмножество фильтра свертки в Pytorch? - PullRequest
1 голос
/ 09 апреля 2020
for batch_idx, (inputs, targets) in enumerate(testloader):
    optimizer.zero_grad()

    inputs = inputs.float().cuda()
    inputs, targets = inputs.to(device), targets.to(device)


    outputs = im_net(inputs)


    ce_loss = criterion(outputs, targets)

    loss = criterion(outputs, targets)


    loss.backward()
    im_net.features[25].weight.grad[temp_5.values] = 0

    print(im_net.features[25].weight.grad[temp_5.values])


    optimizer.step()

Выше приведен мой код для попытки исправить подмножество (temp_5.values) фильтра 25-го слоя свертки. Я пытаюсь установить градиент фильтра поднабора как 0, чтобы не обновлять. Когда я пытаюсь напечатать градиент, все они равны нулю, но вес изменился.

Итак, два вопроса:

  1. Интересно, почему нулевой градиент приводит к изменению веса;
  2. Есть ли другое решение для исправления фильтра свертки подмножества?

1 Ответ

0 голосов
/ 09 апреля 2020

Способ сделать это в pytorch - установить для require_grad значение False.

Эта строка должна помочь:

im_net.features[25].weight.requires_grad = False
...