Что делает model.eval () в pytorch? - PullRequest
       5

Что делает model.eval () в pytorch?

3 голосов
/ 01 февраля 2020

Я использую этот код и видел model.eval() в некоторых случаях.

Я понимаю, что это должно позволить мне "оценить мою модель", но я не понять, когда я должен и не должен использовать его, или как выключить, если выключен.

Пожалуйста, просветите меня.

Я хотел бы запустить приведенный выше код для обучения сети, а также быть возможность запуска проверки каждую эпоху. Я все еще не мог этого сделать.

1 Ответ

5 голосов
/ 01 февраля 2020

model.eval() является своего рода переключателем для некоторых определенных c слоев / частей модели, которые ведут себя по-разному во время обучения и вывода (оценки) времени. Например, Dropouts Layers, BatchNorm Layers et c. Вы должны отключить их во время оценки модели, и .eval() сделает это за вас. Кроме того, общая практика оценки / проверки заключается в использовании torch.no_grad() в паре с model.eval() для отключения вычисления градиентов:

# evaluate model:
model.eval()

with torch.no_grad():
    ...
    out_data = model(data)
    ...

НО, не забудьте вернуться в режим training после шага оценки:

# training step
...
model.train()
...
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...