Pytorch: понять, как работает класс nn.Module - PullRequest
3 голосов
/ 11 ноября 2019

Как правило, nn.Module может наследоваться подклассом, как показано ниже.

def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)  # 

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.fc1 = nn.Linear(20, 1)
        self.apply(init_weights)

    def forward(self, x):
        x = self.fc1(x)
        return x

Мой 1-й вопрос: почему я могу просто запустить приведенный ниже код, даже если у моего __init__ нет никаких позитивных аргументов для training_signals, и похоже, что training_signals передается forward() метод. Как это работает?

model = LinearRegression()
training_signals = torch.rand(1000,20)
model(training_signals)

Второй вопрос: как внутренне работает self.apply(init_weights)? Он выполняется перед вызовом forward метода?

1 Ответ

3 голосов
/ 11 ноября 2019

В1: Почему я могу просто запустить приведенный ниже код, даже если у моего __init__ нет позиционных аргументов для training_signals, и похоже, что training_signals передается методу forward(). Как это работает?

Во-первых, __init__ вызывается при запуске этой строки:

model = LinearRegression()

Как видите, вы не передаете никаких параметров и не должны«т. Подпись вашего __init__ такая же, как и у базового класса (который вы вызываете при запуске super(LinearRegression, self).__init__()). Как вы можете видеть здесь , сигнатура инициализации nn.Module просто def __init__(self) (как и ваша).

Во-вторых, model теперь является объектом. Когда вы запускаете строку ниже:

model(training_signals)

Вы фактически вызываете метод __call__ и передаете training_signals в качестве позиционного параметра. Как вы можете видеть здесь , среди прочего метод __call__ вызывает метод forward:

result = self.forward(*input, **kwargs)

, передавая все параметры (позиционные и именованные) __call__ к forward.

Q2: Как внутренне работает self.apply(init_weights)? Выполняется ли он перед вызовом метода пересылки?

PyTorch - это Open Source, так что вы можете просто перейти к исходному коду и проверить его. Как вы можете видеть здесь , реализация довольно проста:

def apply(self, fn):
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self

Цитирование документации функции: она " применяет fn рекурсивно к каждому подмодулю (как возвращенона .children()), а также self ". Основываясь на реализации, вы также можете понять требования:

  • fn должен вызываться;
  • fn получает на вход только объект Module;
...