Модель предсказания Кераса - как получить первый элемент - PullRequest
0 голосов
/ 30 апреля 2020

Привет, я делаю маленькую модель, которая предсказывает фрукты. У меня есть функция, которая запускает серию pred для различных изображений, и они выводят прогноз, как показано ниже.

[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

Я знаю, что вывод не является массивом, однако я хотел бы знать, есть ли возможный способ проверить значения в определенных позициях. Например, если бы это был массив, я бы сделал:

if prediction[0] == 1:
    print(prediction, "Apple")

Однако, поскольку это не так, я понятия не имею, как проверить значения внутри него. Есть ли способ, которым я могу проверить?

Функция такова:

def fruit_prediction(image_dir):
    img_list = os.listdir(image_dir)
    print(img_list)
    for fruits in img_list:
        path = os.path.join(image_dir, fruits)
        img = image.load_img(path, target_size = (150, 150))
        array = image.img_to_array(img)
        x = np.expand_dims(array, axis=0)

        vimage = np.vstack([x])
        prediction = model.predict(vimage)
        print(prediction, fruits)

Ответы [ 2 ]

1 голос
/ 30 апреля 2020

Использование np.argmax () может решить вашу проблему

здесь np.argmax (прогноз) вернет индекс с наибольшей вероятностью.

напр.

#let your food rep index in one hot encoding like below.
fruit={0:'apple',1:'orange',2:'banana'....like that}

Теперь у вас есть прогноз, как вы рассчитали в приведенной выше функции

 index=np.argmax(prediction)
 print("fruit name",fruit[index])
0 голосов
/ 30 апреля 2020

Привет, ребята. Мне удалось это исправить с помощью функцииpretion.argmax (), например,

if prediction.argmax() == 0:
    print(np.around(prediction, 3), "\n(File: ", fruit, ") Prediction: Apple\n")
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...