отслеживать все веса и градиенты в сети факелов? - PullRequest
0 голосов
/ 23 февраля 2020

Я пытаюсь овладеть статистическими (средними, стандартными) свойствами моих слоев во время обучения.
В настоящее время я использую прямой и обратный хуки, которые, как я считаю, есть у тысяч людей. написано ранее, но не смог найти никаких ссылок на такие репозитории.
Мой код, который просто объединяет все веса, выходные данные и градиенты в словари (и в дальнейшем визуализируется):

def track_network(Net,moduleList):
    outputs = {}
    weights = {}
    gradients = {}
    layers = list(dict(moduleList.named_children()).values())

    def get_activation(name):
        def hook(model, input, output):
            outputs.setdefault(model.__dict__['module_name'], []).append(output.detach().cpu().view(-1).numpy())
            weights.setdefault(model.__dict__['module_name'], []).append(model.weight.detach().cpu().view(-1).numpy())
        return hook

    def get_grad_hook(name):
        def grad_hook(model, grad_input, grad_output):
            gradients.setdefault(model.__dict__['module_name'], []).append(grad_output)
        return grad_hook

    def register_childern(children_dict,parent=""):
        for module_name, module in children_dict.items():
            if not isinstance(module, torch.nn.Linear):
                register_childern(dict(module.named_children()),parent+module_name)
            else:
                module.__dict__['module_name'] = parent+module_name
                module.register_forward_hook(get_activation(module_name))
                module.register_backward_hook(get_grad_hook(module_name))

    register_childern(dict(Net.named_children()))

Мой вопрос: существует ли стандартный / прямой способ отслеживания такие атрибуты сети?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...