Как правильно настроить параметры Conv2D? - PullRequest
0 голосов
/ 19 марта 2020

Я пытаюсь построить модель CNN на наборе данных MNIST, но появляется ошибка, и я не могу ее решить. Это мой код

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense, Flatten, Conv2D
from numpy import expand_dims
(x_train, y_train), (x_test, y_test) = mnist.load_data() 
x_train = expand_dims(x_train, 3)
x_test = expand_dims(x_test, 3)
y_train = to_categorical(y_train) 
y_test = to_categorical(y_test)
model = Sequential()
model.add(Conv2D(64, kernel_size=3, activation="relu", input_shape=(28, 28, 1)))
model.add(Conv2D(64, kernel_size=3, activation="relu"))
model.add(Flatten())
model.add(Dense(10, activation="softmax"))
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=3)
model.save("model")
new_model = tf.keras.models.load_model("model")
predictions = new_model.predict([x_test])
print(predictions[10])

Я получил эту ошибку

TypeError: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64

Любая помощь?

1 Ответ

0 голосов
/ 19 марта 2020

Вы получаете эту ошибку, потому что ваш ввод имеет тип uint8, а ваша сеть имеет значения с плавающей точкой. Вы просто должны привести свои данные к типу данных с плавающей точкой. Здесь, измените свою часть данных на это:

import numpy as np
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = expand_dims(x_train, -1)
x_test = expand_dims(x_test, -1)
x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)

Странно, я не получил ошибку, но я должен был, и это решение.

...