Я настраиваю модель BERT, используя библиотеку трансформеров, и обучаю ее на GPU в облаке. Затем я сохраняю модель и токенизатор, как показано ниже:
model.save_pretrained('/saved_model/')
torch.save(best_model.state_dict(), '/saved_model/model')
tokenizer.save_pretrained('/saved_model/')
Я загружаю каталог saved_model
на свой компьютер. Затем я загружаю модель / токенайзер, как показано ниже, в мой компьютер
import torch
from transformers import *
tokenizer = BertTokenizer.from_pretrained('./saved_model/')
config = BertConfig('./saved_model/config.json')
model = BertModel(config)
model.load_state_dict(torch.load('./saved_model/pytorch_model.bin', map_location=torch.device('cpu')))
model.eval()
Но он выдает ошибку ниже для model.load_state_dict
строки
RuntimeError: Error(s) in loading state_dict for BertModel:
Missing key(s) in state_dict:
В ней перечислены несколько ключей, которые, по-видимому,отсутствует в state_dict.
Я новичок в pytorch и не уверен, что происходит. Скорее всего, я не сохраняю модель правильным способом.
Пожалуйста, предложите.