Сохранение и загрузка контрольной точки модели Pytorch для вывода не работает - PullRequest
0 голосов
/ 19 января 2019

У меня есть обученная модель, использующая LSTM.Модель обучается на GPU (в Google COLABORATORY).Я должен сохранить модель для вывода;который я буду запускать на CPU .После обучения я сохранил контрольную точку модели следующим образом:

torch.save({'model_state_dict': model.state_dict()},'lstmmodelgpu.tar')

И, для вывода, я загрузил модель как:

# model definition
vocab_size = len(vocab_to_int)+1 
output_size = 1
embedding_dim = 300
hidden_dim = 256
n_layers = 2

model = SentimentLSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)

# loading model
device = torch.device('cpu')
checkpoint = torch.load('lstmmodelgpu.tar', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

Но возникает следующая ошибка:

model.load_state_dict(checkpoint['model_state_dict'])
  File "workspace/envs/envdeeplearning/lib/python3.5/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SentimentLSTM:
    Missing key(s) in state_dict: "embedding.weight". 
    Unexpected key(s) in state_dict: "encoder.weight".

Что я пропустил при сохранении контрольной точки?

1 Ответ

0 голосов
/ 22 января 2019

Здесь нужно учесть две вещи.

  1. Вы упомянули, что вы тренируете свою модель на графическом процессоре и используете ее для вывода на процессор, поэтому вам нужно добавить параметр map_location в load передача функции torch.device ( 'CPU') .

  2. Существует несоответствие ключей state_dict (указано в вашем выходном сообщении), которое может быть вызвано некоторыми отсутствующими ключами или наличием большего количества ключей в state_dict , который вы загружаете, чем используемая вами модель В настоящее время. И для этого вы должны добавить параметр строгий со значением False в функции load_state_dict . Это заставит метод игнорировать несоответствие ключей.

Примечание: попробуйте использовать расширение pt или pth для файлов контрольных точек, как это принято.

...