Как правильно загрузить модель переноса обучения для вывода в PyTorch? - PullRequest
1 голос
/ 25 мая 2019

Я тренирую модель, используя трансферное обучение на основе Resnet152. Основываясь на учебнике по PyTorch, у меня нет проблем с сохранением обученной модели и загрузкой ее для вывода. Однако время, необходимое для загрузки модели, невелико. Я не знаю, правильно ли я это сделал, вот мой код:

Чтобы сохранить обученную модель как состояние dict:

torch.save(model.state_dict(), 'model.pkl')

Чтобы загрузить его для вывода:

model = models.resnet152()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(classes))
st = torch.load('model.pkl', map_location='cuda:0' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(st)
model.eval()

Я рассчитал код и обнаружил, что загрузка первой строки model = models.resnet152() занимает больше всего времени. На CPU требуется 10 секунд для проверки одного изображения. Так что я думаю, что это может быть неправильный способ загрузить его?

Если я сохраню всю модель вместо state.dict следующим образом:

torch.save(model, 'model_entire.pkl')

и протестируйте это так:

model = torch.load('model_entire.pkl')
model.eval()

на том же компьютере для проверки одного изображения требуется всего 5 секунд.

Так что мой вопрос: это правильный способ загрузить state_dict для вывода? Спасибо

1 Ответ

0 голосов
/ 25 мая 2019

Это зависит от того, что вы хотите сделать с моделью позже.

С целью возобновления обучения

Если вы хотите сохранить модель для ее обученияпозже вам понадобится больше, чем state_dict самой модели.Вам также необходимо сохранить состояние оптимизатора, эпох, баллов и т. Д.

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

# load
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Для вывода

Только для модели state_dictбудет достаточно.Но убедитесь, что вы вызываете режим eval после загрузки модели, чтобы слои batchnorm или dropout работали в режиме eval вместо режима обучения.

torch.save(model.state_dict(), filepath)

# load
model.load_state_dict(torch.load(filepath))
model.eval()

См. Официальное руководство по передовым методам если вам это нужно.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...