keras.model.predict повысить ValueError: Ошибка при проверке ввода - PullRequest
0 голосов
/ 14 апреля 2019

Я обучил базовую модель нейронной сети на наборе данных MNIST.Вот код тренинга: (импорт опущен)

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data(path='mnist.npz')
x_train, x_test = x_train/255.0, x_test/255.0

#1st Define the model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape = (28,28)),     #input layer
    tf.keras.layers.Dense(512, activation=tf.nn.relu),  #main computation layer
    tf.keras.layers.Dropout(0.2),                       #Dropout layer to avoid overfitting
    tf.keras.layers.Dense(10, activation=tf.nn.softmax) #output layer / Softmax is a classifier AF
])

#2nd Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

#3rd Fit the model
model.fit(x_train, y_train, epochs=5)

#4th Save the model
model.save('models/mnistCNN.h5')

#5th Evaluate the model
model.evaluate(x_test, y_test)

Я хотел посмотреть, как эта модель работает с моими собственными данными, поэтому я написал сценарий прогнозирования с помощью этого поста ,Мой код предсказания: (импорт опущен)

model = load_model('models/mnistCNN.h5')

for i in range(3):
    img = Image.open(str(i+1) + '.png').convert("L")
    img = img.resize((28,28))
    im2arr = np.array(img)
    im2arr = im2arr/255
    im2arr = im2arr.reshape(1, 28, 28, 1)
    y_pred = model.predict(im2arr)
    print('For Image',i+1,'Prediction = ',y_pred)

Во-первых, я не понимаю цели этой строки:

im2arr = im2arr.reshape(1, 28, 28, 1)

Если кто-то может пролить свет на то, почему эта строканеобходимо, это было бы очень полезно.

Во-вторых, эта самая строка выдает следующую ошибку:

ValueError: Error when checking input: expected flatten_input to have 3 dimensions, but got array with shape (1, 28, 28, 1)

Чего мне здесь не хватает?

1 Ответ

1 голос
/ 14 апреля 2019

Первое измерение используется для размера партии.Это добавляется keras.model внутри.Так что эта строка просто добавляет его в массив изображений.

im2arr = im2arr.reshape(1, 28, 28, 1)

Ошибка, которую вы получаете, состоит в том, что единственный пример из mnist dataset, который вы использовали для обучения, имеет форму (28, 28), так же как и ваш входной слой.Чтобы избавиться от этой ошибки, вам нужно изменить эту строку на

im2arr = img.reshape((1, 28, 28))
...