загрузка и прогнозирование модели pytorch, AttributeError: объект dict не имеет атрибута предиката - PullRequest
1 голос
/ 06 мая 2019
model = torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
results, labels = predict_function(model, dev_data, version)

> /home/ofsdms/san_mrc/my_utils/data_utils.py(34)predict_squad()
-> phrase, spans, scores = model.predict(batch)
(Pdb) n
AttributeError: 'dict' object has no attribute 'predict'

Как загрузить сохраненную контрольную точку модели Pytorch и использовать ее для прогнозирования. У меня есть модель, сохраненная в расширении .pt

1 Ответ

1 голос
/ 06 мая 2019

контрольная точка, которую вы сохраняете, обычно представляет собой state_dict: словарь, содержащий значения обученных весов, но не фактическая архитектура сети. Фактический вычислительный граф / архитектура сети описывается как класс Python (полученный из nn.Module).
Для использования обученной модели вам необходимо:

  1. Создание model из класса, реализующего вычислительный граф.
  2. Загрузить сохраненный state_dict в этот экземпляр:

    model.load_state_dict(torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
    
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...