Я тренирую модель, используя трансферное обучение на основе Resnet152. Основываясь на учебнике по PyTorch, у меня нет проблем с сохранением обученной модели и загрузкой ее для вывода. Однако время, необходимое для загрузки модели, невелико. Я не знаю, правильно ли я это сделал, вот мой код:
Чтобы сохранить обученную модель как состояние dict:
torch.save(model.state_dict(), 'model.pkl')
Чтобы загрузить его для вывода:
model = models.resnet152()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(classes))
st = torch.load('model.pkl', map_location='cuda:0' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(st)
model.eval()
Я рассчитал код и обнаружил, что загрузка первой строки model = models.resnet152()
занимает больше всего времени. На CPU требуется 10 секунд для проверки одного изображения. Так что я думаю, что это может быть неправильный способ загрузить его?
Если я сохраню всю модель вместо state.dict следующим образом:
torch.save(model, 'model_entire.pkl')
и протестируйте это так:
model = torch.load('model_entire.pkl')
model.eval()
на том же компьютере для проверки одного изображения требуется всего 5 секунд.
Так что мой вопрос: это правильный способ загрузить state_dict для вывода? Спасибо