Я взял функцию с веб-сайта TensorFlow, чтобы показать пакет изображений в моем блокноте. Я хочу напечатать его так, как это показано на веб-сайте с классами изображений выше. Вот код функции:
def show_batch(image_batch, label_batch):
plt.figure(figsize=(10,10))
for n in range(25):
ax = plt.subplot(5,5,n+1)
plt.imshow(image_batch[n])
plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
plt.axis('off')
Проблема с строкой plt.title .... Я получаю сообщение об ошибке: Невозможно преобразовать 1 в EagerTensor типа dol bool
Я не понимаю, в чем проблема, поскольку я обработал свои данные именно так, как это было сделано в учебнике для веб-сайта.
label возвращает массив формы: [False False True False] и должен напечатать имя класса (у меня есть 4 класса) в соответствии с этим. Но это не так. Остальная часть функции работает просто отлично, но бесполезно показывать только изображения, а не имена классов, к которым принадлежит каждое изображение.