Полагаю, это то, что вы сделали по ошибке. Вы сохранили функцию
torch.save(model.state_dict, 'model_state.pth')
вместо state_dict ()
torch.save(model.state_dict(), 'model_state.pth')
В противном случае все должно работать как положено. (Я тестировал следующий код на Colab)
Замените model.state_dict()
на model.state_dict
, чтобы воспроизвести ошибку
import copy
model = TheModelClass()
torch.save(model.state_dict(), 'model_state.pth')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.load_state_dict(copy.deepcopy(torch.load("model_state.pth",device)))