Как изменить уровень активации в модуле предварительной подготовки Pytorch? - PullRequest
1 голос
/ 09 октября 2019

Как изменить уровень активации предварительно обученной сети Pytorch? Вот мой код:

print("All modules")
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

print('Before changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
        child=nn.SELU()
        print(child)
print('after changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

Вот мой вывод:

All modules
ReLU(inplace=True)
Before changing activation
ReLU(inplace=True)
SELU()
after changing activation
ReLU(inplace=True)

Ответы [ 2 ]

2 голосов
/ 09 октября 2019

Я предполагаю, что вы используете интерфейс модуля nn.ReLU для создания слоя активации вместо использования функционального интерфейса F.relu. Если это так, setattr работает для меня.

import torch
import torch.nn as nn

# This function will recursively replace all relu module to selu module. 
def replace_relu_to_selu(model):
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, nn.SELU())
        else:
            replace_relu_to_selu(child)

########## A toy example ##########
net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(3, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True)
          )

########## Test ##########
print('Before changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
# Before changing activation
# ReLU(inplace=True)
# ReLU(inplace=True)


print('after changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
# after changing activation
# SELU()
# SELU(
1 голос
/ 09 октября 2019

._modules решает проблему для меня.

for name,child in net.named_children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        net._modules['relu'] = nn.SELU()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...