Как я могу обновить только некоторые конкретные тензоры в сети с Pytorch? - PullRequest
0 голосов
/ 29 августа 2018

Например, я хочу обновить все веса CNN в Resnet в первые 10 эпох и заморозить остальные.
И с 11-й эпохи я хочу изменить, чтобы обновить всю модель.
Как мне достичь цели?

Ответы [ 2 ]

0 голосов
/ 29 августа 2018

очень просто, так как PYTORCH воссоздает вычислительный граф на лету.

for p in resnet.parameters():
    p.requires_grad = False # this will freeze the module from training suppose that resnet is one of your module

если у вас несколько модулей, просто зациклите их. затем, после 10 эпох, вы просто звоните

for p in network.parameters():
    p.requires_grad = True # suppose your whole network is the 'network' module
0 голосов
/ 29 августа 2018

Вы можете установить скорость обучения (и некоторые другие мета-параметры) для каждой группы параметров. Вам нужно только сгруппировать ваши параметры в соответствии с вашими потребностями.
Например, установка разной скорости обучения для конвексных слоев:

import torch
import itertools
from torch import nn

conv_params = itertools.chain.from_iterable([m.parameters() for m in model.children()
                                             if isinstance(m, nn.Conv2d)])
other_params = itertools.chain.from_iterable([m.parameters() for m in model.children()
                                              if not isinstance(m, nn.Conv2d)]) 
optimizer = torch.optim.SGD([{'params': other_params},
                             {'params': conv_params, 'lr': 0}],  # set init lr to 0
                            lr=lr_for_model)

Позже вы можете получить доступ к оптимизатору param_groups и изменить скорость обучения.

См. опции для каждого параметра для получения дополнительной информации.

...