Как построить имя класса из набора данных tenorflow? - PullRequest
0 голосов
/ 09 февраля 2020

Я взял функцию с веб-сайта TensorFlow, чтобы показать пакет изображений в моем блокноте. Я хочу напечатать его так, как это показано на веб-сайте с классами изображений выше. Вот код функции:

def show_batch(image_batch, label_batch):
plt.figure(figsize=(10,10))
for n in range(25):
    ax = plt.subplot(5,5,n+1)
    plt.imshow(image_batch[n])
    plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
    plt.axis('off')

Проблема с строкой plt.title .... Я получаю сообщение об ошибке: Невозможно преобразовать 1 в EagerTensor типа dol bool

Я не понимаю, в чем проблема, поскольку я обработал свои данные именно так, как это было сделано в учебнике для веб-сайта.

label возвращает массив формы: [False False True False] и должен напечатать имя класса (у меня есть 4 класса) в соответствии с этим. Но это не так. Остальная часть функции работает просто отлично, но бесполезно показывать только изображения, а не имена классов, к которым принадлежит каждое изображение.

1 Ответ

1 голос
/ 10 февраля 2020

Я не нашел красивый способ сделать это, поэтому я сделал это с дополнительным for-l oop. Я прошел через серию меток и сохранил индекс с истинным значением.

def show_batch(image_batch, label_batch):
plt.figure(figsize=(10,10))
for n in range(25):
    ax = plt.subplot(5,5,n+1)
    plt.imshow(np.squeeze(image_batch[n]), cmap = 'gray')
    ix = 0
    for a in label_batch[n]:
        if a == 1:
            break;
        else:
            ix+=1
    plt.title(CLASS_NAMES[ix].title())
    plt.axis('off')  

Просто чтобы прояснить это на примере:

  • Имена классов следующие: [class1, class2, class3, class4]
  • label_batch - это другой массив [false, false, true, false]
  • в данном случае правый индекс равен 2 (отсчет начинается с 0), а класс, который мне нужен, class3 .
...