Как передать параметры в функцию пересылки моего факела nn.module из skorch.NeuralNetClassifier.fit () - PullRequest
1 голос
/ 14 марта 2019

Я расширил nn.Module для реализации своей сети, функция переадресации которой выглядит следующим образом ...

def forward(self, X, **kwargs):

    batch_size, seq_len = X.size()

    length = kwargs['length']
    embedded = self.embedding(X) # [batch_size, seq_len, embedding_dim]
    if self.use_padding:
        if length is None:
            raise AttributeError("Length must be a tensor when using padding")
        embedded = nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=True)
        #print("Size of Embedded packed", embedded[0].size())


    hidden, cell = self.init_hidden(batch_size)
    if self.rnn_unit == 'rnn':
        out, _ = self.rnn(embedded, hidden)
    elif self.rnn_unit == 'lstm':
        out, (hidden, cell) = self.rnn(embedded, (hidden, cell))


    # unpack if padding was used
    if self.use_padding:
        out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first = True)

Я инициализировал скорч NeuralNetClassifier вот так,

net = NeuralNetClassifier(
    model,
    criterion=nn.CrossEntropyLoss,
    optimizer=Adam, 
    max_epochs=8, 
    lr=0.01, 
    batch_size=32
)

Теперь, если я вызову net.fit(X, y, length=X_len), он выдаст ошибку

TypeError: __call__() got an unexpected keyword argument 'length'

В соответствии с документацией функция соответствия ожидает словарь fit_params,

**fit_params : dict
   Additional parameters passed to the ``forward`` method of
   the module and to the ``self.train_split`` call.

иИсходный код всегда отправляет мои параметры в train_split, где очевидно, что аргумент моего ключевого слова не будет распознан.

Есть ли способ передать аргументы моей функции forward?

1 Ответ

0 голосов
/ 27 марта 2019

Параметр fit_params предназначен для передачи информации, относящейся к разбиениям данных и модели, например, к разделенным группам.

В вашем случае вы передаете дополнительные данные вмодуль через fit_params, который не является тем, для чего он предназначен.Фактически, вы могли бы легко столкнуться с проблемами, если бы вы, например, включили в этом случае загрузчик данных поезда, включив перетасовку пакетов, поскольку ваши длины и ваши данные выровнены неправильно.

Лучший способ сделать это уже описанв ответе на ваш вопрос о системе отслеживания проблем :

X_dict = {'X': X, 'length': X_len}
net.fit(X_dict, y)

Поскольку skorch поддерживает dict s, вы можете просто добавить длину к входному диктову и передать его оба вмодуль, красиво упакованный и пропущенный через тот же загрузчик данных.В вашем модуле вы можете получить к нему доступ через параметры в forward:

def forward(self, X, length):
     return ...

Более подробную документацию по этому поведению можно найти в документации .

...