исключения при загрузке контрольной точки модели PyTorch NN - PullRequest
0 голосов
/ 14 апреля 2020

при вызове функции, определенной в следующей ячейке, выдается исключение ' TypeError: forward () принимает 2 позиционных аргумента, но было дано 9' Этот документ предоставляет дополнительную информацию

def load_checkpoint(chkptJP):
checkpoint = torch.load(chkptJP)
model2 = model1(checkpoint['input_size'],
              checkpoint['output_size'],
              checkpoint['fc1'],
              checkpoint['fc2'],
              checkpoint['optimizer_state_dict'],
              checkpoint['epoch'],
              checkpoint['class_to_idx'],
              checkpoint['learning_rate'])
model2.load_state_dict(checkpoint['state_dict'])
return model2

Код, который выписал контрольную точку, выглядит следующим образом:

checkpoint ={'input_size':512,
         'output_size':102,
         'fc1':256,
         'fc2':102,
         'state_dict': model.state_dict(),
         'optimizer_state_dict': optimizer.state_dict(),
         'epoch': epoch+1,
         'class_to_idx': model.class_to_idx,
         'learning_rate': 0.003}
torch.save(checkpoint,chkptJP)

1 Ответ

0 голосов
/ 14 апреля 2020

Ваша ошибка указывает на то, что model1 - это уже созданная сеть, хотя это должен быть класс. Для получения полной информации см. официальную документацию о сохранении (и всегда обращайтесь к ней, если есть сомнения). Я буду ссылаться на него в ответе, поэтому обязательно проверьте его и поймите, что происходит.

Сохранение общей контрольной точки

Ваш код сохраняет общую контрольную точку . Таким способом вы можете сохранить любой словарь и любую нужную вам информацию (это, как правило, Python 's pickle , и вы также можете настроить его аналогичным образом). У вас много информации, некоторые из них не связаны с самой моделью.

Загрузка общей контрольной точки

Как и вы, вы можете загрузить все эти данные через torch.load. Поскольку вы сохранили state_dict (веса), а не весь Model (как выглядит код), вы должны создать новую модель со случайными весами и затем загрузить ее.

Этот код должен быть штраф:

def load_checkpoint(chkptJP):
    checkpoint = torch.load(chkptJP)
    model = ModelClass(
        checkpoint["input_size"],
        checkpoint["output_size"],
        checkpoint["fc1"],
        checkpoint["fc2"],
        checkpoint["optimizer_state_dict"],
        checkpoint["epoch"],
        checkpoint["class_to_idx"],
        checkpoint["learning_rate"],
    )
    model.load_state_dict(checkpoint["state_dict"])
    return model

Обратите внимание ModelClass должен быть классом, а не объектом, как вы это сделали здесь. Если model1 является объектом, то при запуске model1(arg1, ..., arg9) будет вызываться его метод __call__, который, в свою очередь, представляет собой упакованный метод forward, если model1 является экземпляром torch.nn.Module. ModelClass должно быть примерно таким в вашем коде (и, вероятно, где-то определено):

import torch


class ModelClass(torch.nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        fc1,
        fc2,
        optimizer_state_dict,
        epoch,
        class_to_idx,
        learning_rate,
    ):
        # Your initialization code here
        ...

    def forward(tensor):
        # Your forward pass here
        ...

Если у вас нет ModelClass, вам нужно сохранить всю модель отдельно (например, torch.save(model) вместо torch.save(model.state_dict())) и загрузить его целиком (torch.load(PATH) вместо chkp=torch.load(PATH), за которым следует model.load_state_dict, вызываемый для экземпляра)

...