Cifar10 Project Gradient Descent визуализация обучения состязательности в Pytorch - PullRequest
0 голосов
/ 15 апреля 2020

В настоящее время я работаю над проектом, и мне нужно провести тренинг по градиентному спуску по проекту для набора данных CIFAR10. В ходе оценки мне нужно визуализировать прогнозируемые классы по сравнению с чистыми данными. Однако я сталкиваюсь с этой ошибкой всякий раз, когда пытаюсь распечатать предсказанные изображения.

Коды оценки представлены ниже.

# Eval
model.eval()
test_loss = 0
correct = 0
images_so_far = 0
with torch.no_grad():
   for data, target in test_loader:
     data, target = Variable(data), Variable(target)
     output = model(data)
     val_loss = F.cross_entropy(output, target)
     test_loss += val_loss.item()  # sum up batch loss
     pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
     correct += pred.eq(target.data.view_as(pred)).cpu().sum() 

Визуализация градиента с учетом ввода

classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
out_num = 5
adv_list = []
pred_list = []
with torch.no_grad():
   adv_list.append(output.cpu().numpy().squeeze() * 255.0)  # (N, 28, 28)
   pred_list.append(pred.cpu().numpy())
   data = data.squeeze()  # (N, 28, 28)
   data *= 255.0
   adv_list.insert(0, data)
   pred_list.insert(0, target)
   types = ['Benchmark', 'Clean Query']
   fig, _axs = plt.subplots(nrows=len(adv_list), ncols=out_num)
   axs = _axs
   for j, _type in enumerate(types):
        axs[j, 0].set_ylabel(_type)
        for i in range(out_num):
           axs[j, i].set_xlabel('%s' % classes[pred_list[j][i]])
           img = adv_list[j][i]
           img = np.transpose(img, (1, 2, 0))
           img = img.astype(np.uint8)
           axs[j, i].imshow(img)
           axs[j, i].get_xaxis().set_ticks([])
           axs[j, i].get_yaxis().set_ticks([])
plt.tight_layout()

ОШИБКА enter image description here ниже

...