Как определить неправильную классификацию с партиями в Pytorch - PullRequest
0 голосов
/ 20 сентября 2019

У меня есть такой скрипт, в котором использовались пакеты изображений

correct = 0
total = 0
incorrect_classification=[]
for (i, [images, labels]) in enumerate(test_loader):
  images = Variable(images.view(-1, n_pixel*n_pixel))
  outputs = net(images)
  _, predicted = torch.min(outputs.data, 1)
  total += labels.size(0)                    
  correct += (predicted == labels).sum() 
print('Accuracy: %d %%' %
      (100 * correct / total))

При размере пакета 10 каждое перечисление возвращает 10-кратный тензор размера изображения.Как я могу сохранить все неправильные классификации в массивах invalid_classification или false img и их вероятности в словаре, чтобы я мог использовать plt.imshow, чтобы проверить их позже?

Если размер партии равен 1, я мог бы использовать это:

if (predicted==labels).item()==0:
    incorrect_examples.append(images.numpy())

Но с указанным размером партии (например, 100 изображений в партии), как сохранить неправильные классификации?

Заранее спасибо за любые ответы.

1 Ответ

0 голосов
/ 22 сентября 2019

Как уже сказано в комментарии @zihaozhihao, images[predicted==labels] должен выполнить работу.

Другими словами, вы получите маску индексов и затем получите доступ к нужным изображениям с помощью этой маски:

correct = 0
total = 0
incorrect_examples=[]
for (i, [images, labels]) in enumerate(test_loader):
    images = Variable(images.view(-1, n_pixel*n_pixel))
    outputs = net(images)
    _, predicted = torch.min(outputs.data, 1)
    total += labels.size(0)                    
    correct += (predicted == labels).sum() 
    print('Accuracy: %d %%' % (100 * correct / total))

    # if (predicted==labels).item()==0:
    #     incorrect_examples.append(images.numpy())

    idxs_mask = (predicted == labels).view(-1)
    incorrect_examples.append(images[idxs_mask].numpy()) 

view(-1) сгладит маску, которая будет использоваться для маскировки пакетного канала тензора изображений.

В конце цикла (вне его) итены в списке incorrect_examples будут иметь форму [batch_size, n_pixel, n_pixel], и для удобства вы можете сгруппировать их все в один тензор, объединяя их:

incorrect_images = torch.cat(incorrect_examples)
# incorrect_images.size() -> (n_incorrect_images, n_pixel, n_pixel)
...