Как получить правильный прогноз от нейронной сети, обученной на MNIST от kaggle? - PullRequest
1 голос
/ 27 июня 2019

Я натренировал нейронную сеть на наборе данных MNIST от kaggle. У меня проблемы с получением нейронной сети, чтобы предсказать число, которое она получает.

Я не знаю, что попытаться исправитьэта проблема.

'' 'python

    import pandas as pd
    from tensorflow import keras
    import matplotlib.pyplot as plt
    import numpy as np


    mnist=pd.read_csv(r"C:\Users\Chandrasang\python projects\digit-recognizer\train.csv").values
    xtest=pd.read_csv(r"C:\Users\Chandrasang\python projects\digit-recognizer\test.csv").values

    ytrain=mnist[:,0]
    xtrain=mnist[:,1:]

    x_train=keras.utils.normalize(xtrain,axis=1)
    x_test=keras.utils.normalize(xtest,axis=1)

    x=0
    xtrain2=[]
    while True:
        d=x_train[x]
        d.shape=(28,28)
        xtrain2.append(d)
        x+=1
        if x==42000:
            break

    y=0
    xtest2=[]
    while True:
        b=x_test[y]
        b.shape=(28,28)
        xtest2.append(b)
        y+=1
        if y==28000:
            break

    train=np.array(xtrain2,dtype=np.float32)
    test=np.array(xtest2,dtype=np.float32)

    model=keras.models.Sequential()
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(256,activation=keras.activations.relu))
    model.add(keras.layers.Dense(256,activation=keras.activations.relu))
    model.add(keras.layers.Dense(10,activation=keras.activations.softmax))

    model.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
                 metrics=['accuracy'])
    model.fit(train,ytrain,epochs=10)

    ans=model.predict(x_test)
    print(ans[3])

' ''

Я ожидаю, что вывод будет целым числом, вместо этого он даст мне следующий массив:

[2.7538205e-02 1.0337318e-11 2.9973364e-03 5.7095995e-06 1.6916725e-07 6.9060135e-08 1.3406207e-09 1.1861910e-06 1.4758119e-06 9.6945578e-01]

1 Ответ

1 голос
/ 27 июня 2019

Ваш вывод в норме, это вектор вероятностей. У вас есть 10 классов (цифры от 0 до 9), и ваша сеть вычисляет вероятность того, что ваше изображение будет в каждом классе. С учетом ваших результатов ваша сеть классифицировала ваши входные данные как 9 с вероятностью примерно 0,96.

Если вы хотите увидеть только предсказанный класс, как сказал Крис А., используйте predict_classes.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...