Ваша ошибка указывает на то, что model1
- это уже созданная сеть, хотя это должен быть класс. Для получения полной информации см. официальную документацию о сохранении (и всегда обращайтесь к ней, если есть сомнения). Я буду ссылаться на него в ответе, поэтому обязательно проверьте его и поймите, что происходит.
Сохранение общей контрольной точки
Ваш код сохраняет общую контрольную точку . Таким способом вы можете сохранить любой словарь и любую нужную вам информацию (это, как правило, Python 's pickle , и вы также можете настроить его аналогичным образом). У вас много информации, некоторые из них не связаны с самой моделью.
Загрузка общей контрольной точки
Как и вы, вы можете загрузить все эти данные через torch.load
. Поскольку вы сохранили state_dict
(веса), а не весь Model
(как выглядит код), вы должны создать новую модель со случайными весами и затем загрузить ее.
Этот код должен быть штраф:
def load_checkpoint(chkptJP):
checkpoint = torch.load(chkptJP)
model = ModelClass(
checkpoint["input_size"],
checkpoint["output_size"],
checkpoint["fc1"],
checkpoint["fc2"],
checkpoint["optimizer_state_dict"],
checkpoint["epoch"],
checkpoint["class_to_idx"],
checkpoint["learning_rate"],
)
model.load_state_dict(checkpoint["state_dict"])
return model
Обратите внимание ModelClass
должен быть классом, а не объектом, как вы это сделали здесь. Если model1
является объектом, то при запуске model1(arg1, ..., arg9)
будет вызываться его метод __call__
, который, в свою очередь, представляет собой упакованный метод forward
, если model1
является экземпляром torch.nn.Module
. ModelClass
должно быть примерно таким в вашем коде (и, вероятно, где-то определено):
import torch
class ModelClass(torch.nn.Module):
def __init__(
self,
input_size,
output_size,
fc1,
fc2,
optimizer_state_dict,
epoch,
class_to_idx,
learning_rate,
):
# Your initialization code here
...
def forward(tensor):
# Your forward pass here
...
Если у вас нет ModelClass
, вам нужно сохранить всю модель отдельно (например, torch.save(model)
вместо torch.save(model.state_dict()))
и загрузить его целиком (torch.load(PATH)
вместо chkp=torch.load(PATH)
, за которым следует model.load_state_dict
, вызываемый для экземпляра)