Я пытаюсь выполнить обучение с помощью InceptionV3 для набора данных MNIST.
План состоит в том, чтобы прочитать набор данных MNIST, изменить размеры изображений, а затем использовать их для обучения следующим образом:
import numpy as np
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow.compat.v2 as tf
import tensorflow.compat.v1 as tfv1
from tensorflow.python.keras.applications import InceptionV3
tfv1.enable_v2_behavior()
print(tf.version.VERSION)
img_size = 299
def preprocess_tf_image(image, label):
image = tf.image.grayscale_to_rgb(image)
image = tf.image.resize(image, [img_size, img_size])
return image, label
#Acquire MNIST data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
#Convert data to [0,1] range
x_train, x_test = x_train / 255.0, x_test / 255.0
#Add extra dimension to images so that they can be converted to RGB
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape (x_test.shape[0], 28, 28, 1)
x_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
x_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
#Convert images to RGB space and resize
x_train = x_train.map(preprocess_tf_image)
x_test = x_test.map(preprocess_tf_image)
img_shape = (img_size, img_size, 3)
#Get trained model, but leave off the head
base_model = InceptionV3(input_shape = img_shape, weights='imagenet', include_top=False)
base_model.trainable = False
#Make a model with a new head
model = tf.keras.Sequential([
base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
#Compile model
model.compile(
optimizer='adam', #tf.keras.optimizers.RMSprop(lr=BASE_LEARNING_RATE),
loss='binary_crossentropy',
metrics=['accuracy']
)
model.fit(x_train, epochs=5)
model.evaluate(x_test)
Но, когда я запускаю это, все останавливается на model.fit()
с ошибкой:
ValueError: Ошибка при проверке ввода: ожидалось, что inception_v3_input имеет 4 измерения, но получил массив сформа (299, 299, 3)
Что происходит?