Я новичок в pytorch. Я реализую сеть с двумя классификаторами (желтым и фиолетовым), как показано на рисунке . Проблема в том, что я хочу заморозить красную часть, когда сеть тренирует желтый классификатор, и разморозить красную часть, когда сеть тренирует фиолетовый классификатор.
Краткий код, который я предполагаю реализовать, приведен ниже
# x is input and y_yellow, y_purple are labels of yellow and purple classifiers respectively.
criterion = CrossEntropyLoss()
opt = SGD()
model = my_model()
opt.zero_grad()
yellow_out, purple_out = model(x)
# freeze red part requires_grad = False
yellow_loss = criterion(yellow_out, y_yellow)
yellow_loss.backward()
opt.step()
opt.zero_grad()
# unfreeze red part requires_grad = True
purple_loss = criterion(purple_out, y_purple)
purple_loss.backward()
opt.step
Пожалуйста, дайте мне знать точный способ реализации идеи.
Правильно ли последовательность кода?
Правильно ли я использовал zero_grad?
Я что-то пропустил?
Есть ли какой-нибудь дополнительный параметр, который мне нужно использовать?