PyTorch: противоречивый предтренированный VGG выход - PullRequest
1 голос
/ 06 мая 2019

При загрузке предварительно обученной сети VGG с модулем torchvision.models и использовании его для классификации произвольного изображения RGB выход сети заметно отличается от вызова к вызову.Почему это происходит?Насколько я понимаю, ни одна часть прямого прохода VGG не должна быть недетерминированной.

Вот MCVE:

import torch
from torchvision.models import vgg16

vgg = vgg16(pretrained=True)

img = torch.randn(1, 3, 256, 256)

torch.all(torch.eq(vgg(img), vgg(img))) # result is 0, but why?

1 Ответ

1 голос
/ 06 мая 2019

vgg16 имеет слой nn.Dropout, который во время тренировки случайным образом сбрасывает 50% своих входных данных.Во время тестирования вы должны «отключить» это поведение, установив режим сети в режим «eval»:

vgg.eval()
torch.all(torch.eq(vgg(img), vgg(img)))
Out[73]: tensor(1, dtype=torch.uint8)

Обратите внимание, что есть другие слои со случайнымповедение и другое поведение для обучения и оценки (например, BatchNorm ).Поэтому важно перейти в режим eval(), прежде чем оценивать обученную модель.

...