RuntimeError: Ошибки при загрузке state_dict для BertModel - PullRequest
0 голосов
/ 18 октября 2019

Я настраиваю модель BERT, используя библиотеку трансформеров, и обучаю ее на GPU в облаке. Затем я сохраняю модель и токенизатор, как показано ниже:

model.save_pretrained('/saved_model/')
torch.save(best_model.state_dict(), '/saved_model/model')
tokenizer.save_pretrained('/saved_model/')

Я загружаю каталог saved_model на свой компьютер. Затем я загружаю модель / токенайзер, как показано ниже, в мой компьютер

import torch
from transformers import *
tokenizer = BertTokenizer.from_pretrained('./saved_model/')
config = BertConfig('./saved_model/config.json')
model = BertModel(config)
model.load_state_dict(torch.load('./saved_model/pytorch_model.bin', map_location=torch.device('cpu')))
model.eval()

Но он выдает ошибку ниже для model.load_state_dict строки

RuntimeError: Error(s) in loading state_dict for BertModel:
    Missing key(s) in state_dict:

В ней перечислены несколько ключей, которые, по-видимому,отсутствует в state_dict.

Я новичок в pytorch и не уверен, что происходит. Скорее всего, я не сохраняю модель правильным способом.

Пожалуйста, предложите.

1 Ответ

0 голосов
/ 18 октября 2019

Как вы, возможно, знаете, state_dict модуля PyTorch - это OrderedDict. Когда вы пытались загрузить вес модуля из state_dict, он жалуется на отсутствие ключей, что означает, что state_dict не содержит этих ключей. В этой ситуации я бы предложил выполнить следующие действия:

  1. Проверьте, какие ключи присутствуют в state_dict. Невозможно сохранить только подмножество клавиш.
  2. Кроме того, убедитесь, что загружена правильная конфигурация. В противном случае, если ваш обученный BertModel и новый BertModel, для которого вы хотите загрузить веса, отличаются, вы получите эту ошибку.
  3. Наконец, если ваш код проходит оба вышеупомянутых случая, то сохраните модель,убедитесь, что вы сохранили все параметры слоев в файле. Заявление torch.save(best_model.state_dict(), '/saved_model/model') выглядит нормально для меня, но убедитесь, что best_model.state_dict() содержит все ожидаемые ключи.
...