Найти все слои ReLU в модели Torchvision - PullRequest
0 голосов
/ 04 октября 2018

После того, как я выбрал предварительно обученную модель из torchvision.models, я хочу, чтобы все экземпляры ReLU были register_backward_hook(f), например:

for pos, module in self.model.features._modules.items():
    for sub_module in module:
        if isinstance(module, ReLU):
            module.register_backward_hook(f)

Для меня проблема в том, какнайти все ReLU в модели.Для densenet161, ReLU существует не только в model.features._modules, но также и в самоопределяемом плотном слое, например.model.features._modules['denseblock1'][0].Для resnet151 ReLU существует в model._modules и его самоопределенном слое, например, model._modules['layer1'].

Есть ли способ найти все ReLU внутри модели?

1 Ответ

0 голосов
/ 04 октября 2018

Более элегантный способ перебора всех компонентов модели - использование метода modules():

from torch import nn

for module in self.model.modules():
  if isinstance(module, nn.ReLU):
    module.register_backward_hook(f)

Если вы не хотите получать все субмодули, толькоБлижайшие, вы можете рассмотреть возможность использования метода children() вместо modules().Вы также можете получить имя подмодуля, используя метод named_modules().

...