Как загрузить и использовать сохраненную модель PyTorch InceptionV3 для классификации изображения - PullRequest
0 голосов
/ 19 декабря 2018

У меня та же проблема, что и Как я могу загрузить и использовать модель PyTorch (.pth.tar) , у которой нет принятого ответа или я могу выяснить, как следовать приведенным советам.

Я новичок в PyTorch.Я пытаюсь загрузить предварительно подготовленную модель PyTorch, указанную здесь: https://github.com/macaodha/inat_comp_2018

Я почти уверен, что мне не хватает клея.

# load the model
import torch
model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')

# try to get it to classify an image
imsize = 256
loader = transforms.Compose([transforms.Scale(imsize), transforms.ToTensor()])

def image_loader(image_name):
    """load image, returns cuda tensor"""
    image = Image.open(image_name)
    image = loader(image).float()
    image = Variable(image, requires_grad=True)
    image = image.unsqueeze(0)  
    return image.cpu()  #assumes that you're using CPU

image = image_loader("test-image.jpg")

Выдает ошибку:

in () ----> 1 model.predict (image)

AttributeError: у объекта 'dict' нет атрибута'gnit

1 Ответ

0 голосов
/ 19 декабря 2018

Проблема

Ваш model на самом деле не модель.Когда оно сохранено, оно содержит не только параметры, но и другую информацию о модели в виде формы, несколько похожей на диктовку.

Следовательно, torch.load("iNat_2018_InceptionV3.pth.tar") просто возвращает dict, что, конечно, неиметь атрибут с именем predict.

model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')
type(model)
# dict

Решение

Что вам нужно сделать сначала в этом случае, и в общих случаях, это создать экземпляр желаемого класса модели в соответствии софициальное руководство «Загрузка моделей» .

# First try
from torchvision.models import Inception3
v3 = Inception3()
v3.load_state_dict(model['state_dict']) # model that was imported in your code.

Однако прямой ввод model['state_dict'] вызовет некоторые ошибки, связанные с несоответствием форм параметров Inception3.

Важно знать, что было изменено на Inception3 после его создания.К счастью, вы можете найти это в оригинальном авторском train_inat.py.

# What the author has done
model = inception_v3(pretrained=True)
model.fc = nn.Linear(2048, args.num_classes) #where args.num_classes = 8142
model.aux_logits = False

Теперь, когда мы знаем, что изменить, давайте внесем некоторые изменения в нашу первую попытку .

# Second try
from torchvision.models import Inception3
v3 = Inception3()
v3.fc = nn.Linear(2048, 8142)
v3.aux_logits = False
v3.load_state_dict(model['state_dict']) # model that was imported in your code.

И вот вам успешно загруженная модель!

...