Я пытаюсь составить график 5 лучших прогнозов из классификатора изображений для любого изображения. Мои файлы организованы так:
- Тест / -> имя_класса / -> номер_образа
, например, - Тест / -> шоколад / -> 0015254 .jpg
Когда я использую следующий код, я не могу получить имя класса для перенаправления на печать, это выдает ошибку, которая приводит к выводу каждого класса (показано ниже).
class_to_idx:
classes = os.listdir(train_folder)
classes.sort()
#print(classes)
label_mapping = {k: v for v, k in enumerate(classes)}
class_to_idx = {classes[i]: i for i in range(len(classes))}
Вот код для создания изображения и печати его топ-5 прогнозов:
def sanity_check(path):
imagepath = test_folder + path
image = process_image(imagepath)
plot = imshow(image, ax = plt)
plot.axis('off')
plot.title(class_to_idx[str(classes)])
plot.show()
axes = predict(imagepath, model)
yaxis = [class_to_idx[str(i)] for i in np.array(axes[1][0].cpu())]
y_pos = np.arange(len(yaxis))
xaxis = np.array(axes[0][0].cpu().numpy())
plt.barh(y_pos, xaxis)
plt.xlabel('probability')
plt.yticks(y_pos, yaxis)
plt.title('probability of {} classification'.format(name))
plt.show()
path = '/chocolate_mousse/1379570.jpg'
sanity_check(path)
Вот ошибка:
KeyError Traceback (most recent call last)
<ipython-input-64-5747e27aee19> in <module>()
1 path = '/chocolate_mousse/1379570.jpg'
----> 2 sanity_check(path, name = 'Chocolate Mousse')
<ipython-input-62-295700419574> in sanity_check(path, name)
7 plot = imshow(image, ax = plt)
8 plot.axis('off')
----> 9 plot.title(class_to_idx[str(classes)])
10 plot.show()
11
KeyError: "['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', .... etc