Я пытаюсь обучить модель InceptionV3 в Keras.
Мой набор данных предварительно приведён к форме (229, 229, 3).
# Inspect the dataset: the array's shape, the container type, and the type
# of a single sample.
print(data.shape)
print(type(data))
print(type(data[0]))
Вывод:
(1458, 229, 229, 3)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Я инициализирую свою модель следующим образом
import os, sys
from keras.optimizers import SGD
from keras.applications import InceptionV3
# Build InceptionV3 with its default arguments.
# NOTE(review): the error reported for this model says input_1 expects shape
# (299, 299, 3), while the data is (229, 229, 3) — confirm the intended
# preprocessing size.
model = InceptionV3()
# compile model
opt = SGD(lr=0.05)
model.compile(loss="categorical_crossentropy", optimizer=opt,
metrics=["accuracy"])
Вызываю model.fit
# train the network
print("[INFO] training network...")
# H keeps whatever model.fit returns; (test_x, test_y) is passed as the
# validation data.
H = model.fit(train_x, train_y, validation_data=(test_x, test_y),
batch_size=batch_size, epochs=num_of_epochs, verbose=1)
После этого я получаю следующую ошибку. Я не понимаю, в чём причина: мне кажется, что размеры правильные.
[INFO] training network...
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
in
2 print("[INFO] training network...")
3 H = model.fit(train_x, train_y, validation_data=(test_x, test_y),
----> 4 batch_size=batch_size, epochs=num_of_epochs, verbose=1)
5
6 model.save(model_save_path)
~/anaconda3/lib/python3.7/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
950 sample_weight=sample_weight,
951 class_weight=class_weight,
--> 952 batch_size=batch_size)
953 # Prepare validation data.
954 do_validation = False
~/anaconda3/lib/python3.7/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
749 feed_input_shapes,
750 check_batch_axis=False, # Don't enforce the batch size.
--> 751 exception_prefix='input')
752
753 if y is not None:
~/anaconda3/lib/python3.7/site-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
136 ': expected ' + names[i] + ' to have shape ' +
137 str(shape) + ' but got array with shape ' +
--> 138 str(data_shape))
139 return data
140
ValueError: Error when checking input: expected input_1 to have shape (299, 299, 3) but got array with shape (229, 229, 3)
ПРАВКА
# Value passed to model.fit as batch_size.
batch_size = 32
Вот как изменялись размеры изображений:
import imutils
import cv2
class AspectAwarePreprocessor:
    """Resize an image to a fixed size while keeping its aspect ratio.

    The image is scaled along one axis to match the target, center-cropped
    along the other axis, and finally passed through ``cv2.resize`` so the
    output has exactly the requested width and height.
    """

    def __init__(self, width, height, inter=cv2.INTER_AREA):
        """Store the target size and the interpolation method.

        Args:
            width: Desired output width.
            height: Desired output height.
            inter: Interpolation method used when resizing the image.
        """
        self.width = width
        self.height = height
        self.inter = inter

    def preprocess(self, image):
        """Return ``image`` resized and center-cropped to the target size.

        Args:
            image: Image to be preprocessed; ``image.shape[:2]`` is read as
                ``(height, width)``.

        Returns:
            The result of ``cv2.resize`` to ``(self.width, self.height)``.
        """
        (h, w) = image.shape[:2]

        # Scale along the shorter side first (width when w < h, height
        # otherwise) and crop the other axis.
        by_width = w < h
        resized, dW, dH = self._scale_and_offsets(image, by_width)

        # With a non-square target the scaled image can end up smaller than
        # the target along the axis to be cropped. The offset is then
        # negative and slicing with it would keep only a thin strip from the
        # far edge. Scaling along the other axis instead makes the image
        # cover the target, so the crop offsets are non-negative.
        if dW < 0 or dH < 0:
            resized, dW, dH = self._scale_and_offsets(image, not by_width)

        # Re-grab the size and crop the center using the offsets.
        (h, w) = resized.shape[:2]
        cropped = resized[dH:h - dH, dW:w - dW]

        # The crop may be off by one pixel because of integer truncation,
        # so a final resize enforces the exact target dimensions.
        return cv2.resize(cropped, (self.width, self.height),
                          interpolation=self.inter)

    def _scale_and_offsets(self, image, by_width):
        """Scale ``image`` along one axis and compute the crop offsets.

        Returns a ``(resized, dW, dH)`` tuple; the offset of the axis that
        was scaled to the target is always 0.
        """
        if by_width:
            resized = imutils.resize(image, width=self.width,
                                     inter=self.inter)
            return resized, 0, int((resized.shape[0] - self.height) / 2.0)

        resized = imutils.resize(image, height=self.height,
                                 inter=self.inter)
        return resized, int((resized.shape[1] - self.width) / 2.0), 0