PyTorch: различные передовые методы для обучения и тестирования / проверки - PullRequest
0 голосов
/ 01 ноября 2019

В настоящее время я пытаюсь расширить модель , основанную на FairSeq / PyTorch. Во время обучения мне нужно обучить два кодера: один с целевой выборкой и исходный с исходной выборкой.

Таким образом, текущая функция пересылки выглядит следующим образом:

def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

И на основена этой этой идее я хочу что-то вроде этого:

def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    concat = some_concatination_func(encoder_out, autoencoder_out)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    return decoder_out

Есть ли способ сделать это?

Редактировать: это ограничения, которые у меня есть, поскольку мне нужно расширить FairseqEncoderDecoderModel :

@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder) 

Редактировать 2: параметры, переданные функции пересылки в Fairseqможет быть изменен путем реализации вашего собственного критерия, см., например, CrossEntropyCriterion , где sample['net_input'] передается функции __call__ модели, которая вызывает метод forward.

Ответы [ 2 ]

3 голосов
/ 01 ноября 2019

Прежде всего вы должны всегда использовать и определять forward, а не некоторые другие методы, которые вы вызываете для экземпляра torch.nn.Module.

Определенно не перегружайте eval(), как показано trsvchn , поскольку его метод оценки определен PyTorch ( см. Здесь ). Этот метод позволяетслои внутри вашей модели должны быть переведены в режим оценки (например, определенные изменения в слоях, такие как режим вывода для Dropout или BatchNorm).

Более того, вы должны вызывать его с помощью __call__ магического метода. Почему? Поскольку хуки и другие специфичные для PyTorch вещи регистрируются таким образом правильно.

Во-вторых, не используйте какую-либо внешнюю строковую переменную mode, как предложено @ Anant Mittal . Для этого и используется переменная train в PyTorch, по ней принято различать, находится ли модель в режиме eval или train.

При этом лучше всего делать это следующим образом:

import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...

    # You could split it into two functions but both should be called by forward
    def forward(
        self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs
    ):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        if self.train:
            return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
        concat = some_concatination_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)

Вы можете (и, возможно, должны) разделить вышеперечисленное на два отдельных метода, но это не так уж и плохо, так как функция довольно короткая и удобочитаемая. Просто придерживайтесь метода обработки PyTorch, если это возможно, а не каких-то специальных решений. И нет, с обратным распространением проблем не будет, с чего бы это?

0 голосов
/ 01 ноября 2019

По умолчанию вызов model() вызывает метод forward, который в вашем случае обучается вперед, поэтому вам просто нужно определить новый метод для вашего пути test / eval внутри класса модели, что-то вроде этого:

Код:

class FooBar(nn.Module):
    """Dummy Net for testing/debugging.
    """

    def __init__(self):
        super().__init__()
        ...

    def forward(self, x):
        # here will be train forward
        ...

    def evaltest(self, x):
        # here will be eval/test forward
        ...

Примеры:

model = FooBar()  # initialize model 

# train time
pred = model(x)   # calls forward() method under the hood

# test/eval time
test_pred = model.evaltest(x)

Комментарий: Я хотел бы рекомендовать вам разделить эти два прямых пути на 2 отдельных метода, потому чтоего легче отлаживать и избежать возможных проблем при обратном распространении.

...