Pytorch: можем ли мы использовать слои nn.Module непосредственно в функции forward ()? - PullRequest
3 голосов
/ 08 января 2020

Как правило, в конструкторе мы объявляем все слои, которые хотим использовать. В функции forward мы определяем, как будет выполняться модель, от ввода к выводу.

Мой вопрос заключается в том, что если вызывать эти предопределенные / встроенные nn.Modules непосредственно в forward() функция? Является ли этот API-интерфейс функции Keras допустимым для Pytorch стиля? Если нет, то почему?

Обновление: TestModel , созданный таким образом, успешно прошел без тревоги. Но потеря тренировки будет медленно снижаться по сравнению с обычным способом.

import torch.nn as nn
from cnn import CNN

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_embeddings = 2020
        self.embedding_dim = 51

    def forward(self, input):
        x = nn.Embedding(self.num_embeddings, self.embedding_dim)(input)
        # CNN is a customized class and nn.Module subclassed
        # we will ignore the arguments for its instantiation 
        x = CNN(...)(x)
        x = nn.ReLu()(x)
        x = nn.Dropout(p=0.2)(x)
        return output = x

Ответы [ 2 ]

3 голосов
/ 08 января 2020

Вам нужно подумать о области действия обучаемых параметров.

Если вы определите, скажем, слой конвона в функции forward вашей модели, тогда область действия этот «слой» и его обучаемые параметры являются локальными для функции и будут отбрасываться после каждого вызова метода forward. Вы не можете обновлять и обучать веса, которые постоянно сбрасываются после каждого прохода forward.
Однако, когда слой conv является членом вашего model, его область действия выходит за пределы метода forward, а обучаемые параметры сохраняются как Пока существует объект model. Таким образом, вы можете обновить и обучить модель и ее вес.

2 голосов
/ 08 января 2020

То, что вы пытаетесь сделать, можно сделать, но не следует, поскольку в большинстве случаев это совершенно не нужно. И это не более читаемое IMO и определенно против пути PyTorch.

В ваших forward слоях каждый раз происходит повторная инициализация, и они не регистрируются в вашей сети.

Чтобы сделать это правильно, вы можете использовать Module Функция add_module() с защитой от переназначения (метод dynamic ниже):

import torch

class Example(torch.nn.Module):
    def __init__(self):
        self.num_embeddings = 2020        
        self.embedding_dim = 51 

    def dynamic(self, name: str, module_class, *args, **kwargs):
        if not hasattr(self, name):
            self.add_module(name, module_class(*args, **kwargs))
        return getattr(self, name)

    def forward(self, x):
        embedded = self.dynamic("embedding",
                     torch.nn.Embedding,
                     self.num_embeddings,
                     self.embedding_dim)(x)
        return embedded

Вы можете структурировать это по-другому, но за этим стоит идея.

Реально Примером использования может быть случай, когда создание слоя каким-то образом зависит от данных, передаваемых в forward, но это может указывать на некоторое fl aws в разработке программы.

...