Почему существует разный выход между model.forward (вход) и модель (вход) - PullRequest
2 голосов
/ 25 марта 2019

Я использую pytorch для создания простой модели, такой как VGG16, и в моей модели перегружена функция forward.

Я обнаружил, что все склонны использовать model(input) для получения результата, а не model.forward(input), и меня интересует разница между ними. Я пытаюсь ввести те же данные, но результат отличается. Я в замешательстве.

Я вывел layer_weight до того, как я ввел данные, вес не изменился, и я знаю, когда мы используем model(input) его с помощью функции __call__, и эта функция вызовет model.forward.

   vgg = VGG()
   vgg.double()
   for layer in vgg.modules():
      if isinstance(layer,torch.nn.Linear):
         print(layer.weight)
   print("   use model.forward(input)     ")
   result = vgg.forward(array)

   for layer in vgg.modules():
     if isinstance(layer,torch.nn.Linear):
       print(layer.weight) 
   print("   use model(input)     ")
   result_2 = vgg(array)
   print(result)
   print(result_2)

выход:

    Variable containing:1.00000e-02 *
    -0.2931  0.6716 -0.3497 -2.0217 -0.0764  1.2162  1.4983 -1.2881
    [torch.DoubleTensor of size 1x8]

    Variable containing:
    1.00000e-02 *
    0.5302  0.4494 -0.6866 -2.1657 -0.9504  1.0211  0.8308 -1.1665
    [torch.DoubleTensor of size 1x8]

1 Ответ

5 голосов
/ 25 марта 2019

model.forward просто вызывает операции пересылки, как вы упомянули, но __call__ делает немного больше.

Если вы покопаетесь в коде класса nn.Module, вы увидите, что __call__ в конечном счете вызывает переадресацию, но внутренне обрабатывает перехваты вперед или назад и управляет некоторыми состояниями, которые допускает pytorch. При вызове простой модели, такой как просто MLP, она может и не понадобиться, но более сложные модели, такие как слои спектральной нормализации, имеют зацепки, и поэтому вы должны использовать сигнатуру model(.) как можно чаще, если вы явно не хотите просто вызывать model.forward

Также см. Вызов функции переадресации без .forward ()

В этом случае, однако, разница может быть связана с некоторым выпадающим слоем, вам следует позвонить vgg.eval(), чтобы убедиться, что вся стохастичность в сети отключена перед сравнением выходных данных.

...