Дублирование слоев при повторном использовании модели pytorch - PullRequest
5 голосов
/ 08 мая 2020

Я пытаюсь повторно использовать некоторые слои re snet для пользовательской архитектуры и столкнулся с проблемой, которую не могу понять. Вот упрощенный пример; когда я бегу:

import torch
from torchvision import models
from torchsummary import summary

def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )


class ResNetUNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.base_model = models.resnet18(pretrained=False)
        self.base_layers = list(self.base_model.children())


        self.layer0 = nn.Sequential(*self.base_layers[:3])


    def forward(self, x):
        print(x.shape)

        output = self.layer0(x)

        return output

base_model = ResNetUNet().cuda()
summary(base_model,(3,224,224))

Дает мне:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
            Conv2d-2         [-1, 64, 112, 112]           9,408
       BatchNorm2d-3         [-1, 64, 112, 112]             128
       BatchNorm2d-4         [-1, 64, 112, 112]             128
              ReLU-5         [-1, 64, 112, 112]               0
              ReLU-6         [-1, 64, 112, 112]               0
================================================================
Total params: 19,072
Trainable params: 19,072
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 36.75
Params size (MB): 0.07
Estimated Total Size (MB): 37.40
----------------------------------------------------------------

Это дублирует каждый слой (есть 2 свертки, 2 батчнорма, 2 relu) вместо того, чтобы давать по одному слою каждому . Если я распечатаю self.base_layers[:3], я получу:

[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True)]

, что показывает только три слоя без дубликатов. Почему он дублирует мои слои?

Я использую pytorch версии 1.4.0

1 Ответ

1 голос
/ 08 мая 2020

Ваши слои на самом деле не вызываются дважды. Это артефакт реализации summary.

Простая причина в том, что summary рекурсивно выполняет итерацию по всем дочерним элементам вашего модуля и регистрирует переадресацию перехватчиков для каждого из них. Поскольку у вас есть повторяющиеся дочерние элементы (в base_model и layer0), эти повторяющиеся модули получают несколько зарегистрированных крючков. Когда сводные вызовы переадресовывают, это приводит к вызову обоих хуков для каждого модуля, что вызывает повторение уровней.


Для вашего игрушечного примера решением было бы просто не назначать base_model как атрибут, так как он все равно не используется во время пересылки. Это позволяет избежать добавления base_model в качестве дочернего.

class ResNetUNet(nn.Module):
    def __init__(self):
        super().__init__()
        base_model = models.resnet18(pretrained=False)
        base_layers = list(base_model.children())
        self.layer0 = nn.Sequential(*base_layers[:3])

Другое решение - создать модифицированную версию summary, которая не регистрирует хуки для одного и того же модуля несколько раз. Ниже приведен расширенный summary, в котором я использую набор с именем already_registered для отслеживания модулей, в которых уже зарегистрированы хуки, чтобы избежать регистрации нескольких хуков.

from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np

def summary(model, input_size, batch_size=-1, device="cuda"):

    # keep track of registered modules so that we don't add multiple hooks
    already_registered = set()

    def register_hook(module):

        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and not (module == model)
            and module not in already_registered:
        ):
            already_registered.add(module)
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
    # print(type(x[0]))

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook
    model.apply(register_hook)

    # make a forward pass
    # print(x.shape)
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20}  {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]
        total_output += np.prod(summary[layer]["output_shape"])
        if "trainable" in summary[layer]:
            if summary[layer]["trainable"] == True:
                trainable_params += summary[layer]["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")
    # return summary

Примечание У меня было более подробный ответ, который вы можете найти в правках, но, оглядываясь назад, я думаю, что это более простое объяснение легче понять.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...