Keras - Как использовать argmax для прогнозов - PullRequest
0 голосов
/ 13 января 2019

У меня есть 3 категории классов Tree, Stump, Ground. Я сделал список для этих категорий:

CATEGORIES = ["Tree", "Stump", "Ground"]

Когда я печатаю свой прогноз, он дает мне вывод

[[0. 1. 0.]]

Я читал об Argmax от numpy, но я не совсем уверен, как использовать его в этом случае.

Я пытался использовать

print(np.argmax(prediction))

Но это дает мне вывод 1. Это здорово, но я хотел бы узнать, что такое индекс 1, а затем распечатать категорию вместо наибольшего значения.

import cv2
import tensorflow as tf
import numpy as np

CATEGORIES = ["Tree", "Stump", "Ground"]


def prepare(filepath):
    IMG_SIZE = 150 # This value must be the same as the value in Part1
    img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
    return new_array.reshape(-1, IMG_SIZE, IMG_SIZE, 1)

# Able to load a .model, .h3, .chibai and even .dog
model = tf.keras.models.load_model("models/test.model")

prediction = model.predict([prepare('image.jpg')])
print("Predictions:")
print(prediction)
print(np.argmax(prediction))

Я ожидаю, что мой прогноз покажет мне:

Predictions:
[[0. 1. 0.]]
Stump

Спасибо за чтение :) Я ценю любую помощь.

1 Ответ

0 голосов
/ 13 января 2019

Вам просто нужно проиндексировать категории с результатом np.argmax:

pred_name = CATEGORIES[np.argmax(prediction)]
print(pred_name)
...