Попытка напечатать названия классов для породы собак, но он продолжает указывать индекс списка вне диапазона - PullRequest
0 голосов
/ 06 марта 2019

Я использую модель resnet для классификации пород собак, но когда я пытаюсь распечатать изображение с ярлыком породы собаки, в нем указывается индекс списка вне диапазона.Вот мой код:

import torchvision.models as models
import torch.nn as nn


model_transfer = models.resnet18(pretrained=True)

if use_cuda:
    model_transfer = model_transfer.cuda()

model_transfer.fc.out_features = 133

Затем я тренирую модель и получаю более 70% точности по породам собак.

Тогда вот мой код, чтобы классифицировать собаку и распечатать породу собаки:

data_transfer = {'train': 
 datasets.ImageFolder('/data/dog_images/train',transform=transforms.Compose([transforms.RandomResizedCrop(224),transforms.ToTensor()]))}
class_names[0]
class_names = [item[4:].replace("_", " ") for item in data_transfer['train'].classes]

def predict_breed_transfer(img_path):

    image = Image.open(img_path)

    # large images will slow down processing


    in_transform = transforms.Compose([
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])])

    # discard the transparent, alpha channel (that's the :3) and add the batch dimension
    image = in_transform(image)[:3,:,:].unsqueeze(0)

    image = image

    output = model_transfer(image)
    pred = torch.argmax(output)

    return class_names[pred]
    predict_breed_transfer('images/Labrador_retriever_06455.jpg')

Код всегда предсказывает, что собака по какой-то причине неверна Затем, когда я пытаюсь распечатать изображение иlabel:

import matplotlib.pyplot as plt
def run_app(img_path):
    img = Image.open(img_path)
    dog = dog_detector(img_path)
    if not dog: 
        print('hello, human!')
        plt.imshow(img)
        print('You look like a ... ')
        print(predict_breed_transfer(img_path))
    if dog: 
        print('hello, dog!')
        print('Your predicted breed is ....')
        print(predict_breed_transfer(img_path))
        plt.imshow(img)
    else: 
        print('Niether human nor dog')

И запустить цикл for, который вызывает его на изображениях некоторых собак, он выведет некоторые породы, затем скажет, что индекс списка находится вне диапазона и не отобразит ни одного изображения.

Длина class_names равна 133. И когда я распечатываю модель resnet, получается только 133 узла. Кто-нибудь знает, почему он говорит, что индекс списка находится вне диапазона или почему он так неточен.

`IndexError                                Traceback (most recent 
call last)
<ipython-input-26-473a9ba884b5> in <module>()
      5 ## suggested code, below
      6 for file in np.hstack((human_files[:3], dog_files[:3])):
----> 7     run_app(file)
      8 
 <ipython-input-25-1d44200e44cc> in run_app(img_path)
      10         plt.show(img)
      11         print('You look like a ... ')
 ---> 12         print(predict_breed_transfer(img_path))
      13     if dog:
      14         print('hello, dog!')

 <ipython-input-20-a51fb205659e> in predict_breed_transfer(img_path)
      26     pred = torch.argmax(output)
      27 
 ---> 28     return class_names[pred]
      29 
predict_breed_transfer('images/Labrador_retriever_06455.jpg')
      30 

IndexError: list index out of range`

Вот полная ошибка

1 Ответ

0 голосов
/ 06 марта 2019

Полагаю, у вас есть несколько проблем, которые можно исправить с помощью 13 символов.

Сначала я предлагаю то, что предложил @Alekhya Vemavarapu - запустите ваш код с помощью отладчика, чтобы изолировать каждую строку и проверить вывод. Это одно из величайших преимуществ динамических графов с pytorch .

Во-вторых, наиболее вероятная причина вашей проблемы - неправильный оператор argmax. Вы не указываете размер, над которым вы выполняете argmax, и поэтому PyTorch автоматически выравнивает изображение и выполняет операцию над вектором полной длины. Таким образом, вы получите число от 0 до MB_Size x num_classes -1.См. Официальный документ по этому методу .

Итак, из-за вашего полностью подключенного слоя я предполагаю, что ваш вывод имеет форму (MB_Size, num_classes).Если это так, вам нужно изменить код на следующую строку:

pred = torch.argmax(output,dim=1)

и все.В противном случае просто выберите размер логитов.

Третье, что вы хотите учесть, - это отсев и другие влияния, которые конфигурация обучения может оказать на вывод.Например, для исключения из некоторых структур может потребоваться умножить выходной сигнал на 1/(1-p) в выводе (или нет, так как это может быть сделано во время обучения), нормализация партии может быть отменена, поскольку размер пакета отличается, и так далее.Кроме того, чтобы уменьшить потребление памяти, градиенты не должны рассчитываться. К счастью, разработчики PyTorch очень вдумчивы и предоставили нам torch.no_grad() и model.eval() для этого.

Я настоятельно рекомендую попробовать это, возможно, изменив код с помощью несколькихбуквы:

output = model_transfer.eval()(image)

и все готово!

Редактировать :
Это простой случай неправильного использования инфраструктуры PyTorch, а не чтение документы и не отлаживать ваш код.Следующий код является абсолютно неправильным:

model_transfer.fc.out_features = 133

Эта строка фактически не создает новый полностью связанный слой.Это просто меняет свойство этого тензора.Попробуйте в консоли:

import torch
a = torch.nn.Linear(1,2)
a.out_features = 3
print(a.bias.data.shape, a.weight.data.shape)

Вывод:

torch.Size([2]) torch.Size([2, 1])

, который указывает, что фактическая матрица весов и вектор смещений остаются в своем исходном измерении.
Правильный путьвыполнить обучение переноса означает сохранить магистраль (обычно сверточные слои до полностью связанных в этих моделях моделей) и перезаписать голову (в данном случае слой FC) своей.Если это только один полностью связанный слой, который существует в исходной модели, вам не нужно изменять прямой проход вашей модели, и вы готовы к работе.Поскольку этот ответ уже достаточно длинный, просто посмотрите учебник Transfer в документации по PyTorch, чтобы узнать, как это можно сделать.

Удачи вам.

...