Как обучить множественные потери в одной модели с замораживанием какой-то части сети опционально в pytorch? - PullRequest
1 голос
/ 19 июня 2020

Я новичок в pytorch. Я реализую сеть с двумя классификаторами (желтым и фиолетовым), как показано на рисунке . Проблема в том, что я хочу заморозить красную часть, когда сеть тренирует желтый классификатор, и разморозить красную часть, когда сеть тренирует фиолетовый классификатор.

Краткий код, который я предполагаю реализовать, приведен ниже

# x is input and y_yellow, y_purple are labels of yellow and purple classifiers respectively.
criterion = CrossEntropyLoss()
opt = SGD()
model = my_model()

opt.zero_grad()
yellow_out, purple_out = model(x)

# freeze red part requires_grad = False

yellow_loss = criterion(yellow_out, y_yellow)
yellow_loss.backward()
opt.step()

opt.zero_grad()
# unfreeze red part requires_grad = True

purple_loss = criterion(purple_out, y_purple)
purple_loss.backward()
opt.step

Пожалуйста, дайте мне знать точный способ реализации идеи.

  • Правильно ли последовательность кода?

  • Правильно ли я использовал zero_grad?

  • Я что-то пропустил?

  • Есть ли какой-нибудь дополнительный параметр, который мне нужно использовать?

1 Ответ

0 голосов
/ 20 июня 2020

Простейшей стратегией было бы не передавать параметры красного модуля оптимизатору opt. В качестве альтернативы вы можете установить requires_grad как False для параметров красного модуля.

Это будет выглядеть примерно так:

for param in red_module.parameters():
    param.requires_grad = False

Сообщите мне, работает ли это для вас .

...