В настоящее время я работаю над проектом, и мне нужно провести тренинг по градиентному спуску по проекту для набора данных CIFAR10. В ходе оценки мне нужно визуализировать прогнозируемые классы по сравнению с чистыми данными. Однако я сталкиваюсь с этой ошибкой всякий раз, когда пытаюсь распечатать предсказанные изображения.
Коды оценки представлены ниже.
# Eval
model.eval()
test_loss = 0
correct = 0
images_so_far = 0
with torch.no_grad():
for data, target in test_loader:
data, target = Variable(data), Variable(target)
output = model(data)
val_loss = F.cross_entropy(output, target)
test_loss += val_loss.item() # sum up batch loss
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).cpu().sum()
Визуализация градиента с учетом ввода
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
out_num = 5
adv_list = []
pred_list = []
with torch.no_grad():
adv_list.append(output.cpu().numpy().squeeze() * 255.0) # (N, 28, 28)
pred_list.append(pred.cpu().numpy())
data = data.squeeze() # (N, 28, 28)
data *= 255.0
adv_list.insert(0, data)
pred_list.insert(0, target)
types = ['Benchmark', 'Clean Query']
fig, _axs = plt.subplots(nrows=len(adv_list), ncols=out_num)
axs = _axs
for j, _type in enumerate(types):
axs[j, 0].set_ylabel(_type)
for i in range(out_num):
axs[j, i].set_xlabel('%s' % classes[pred_list[j][i]])
img = adv_list[j][i]
img = np.transpose(img, (1, 2, 0))
img = img.astype(np.uint8)
axs[j, i].imshow(img)
axs[j, i].get_xaxis().set_ticks([])
axs[j, i].get_yaxis().set_ticks([])
plt.tight_layout()
ОШИБКА ниже