Я построил классификатор изображений, как показано ниже:
import tensorflow as tf
from tensorflow.keras.applications.mobilenet import preprocess_input
image_width, image_height = 224, 224
input_shape = (image_width, image_height, 3)
self.model = tf.keras.Sequential()
pretrained_layer = tf.keras.applications.mobilenet.MobileNet(
weights="imagenet", include_top=False, input_shape=self.input_shape
)
self.model.add(pretrained_layer)
self.model.add(tf.keras.layers.GlobalAveragePooling2D())
self.model.add(tf.keras.layers.Dense(256, activation="relu"))
self.model.add(tf.keras.layers.Dropout(0.5))
self.model.add(tf.keras.layers.Dense(128, activation="relu"))
self.model.add(tf.keras.layers.Dropout(0.2))
self.model.add(tf.keras.layers.Dense(len(DATA_LABELS), activation="sigmoid"))
self.model.compile(
optimizer=tf.keras.optimizers.Adam(0.0005),
loss="binary_crossentropy",
metrics=["accuracy"],
)
У меня также была функция прогнозирования, которая ожидает ввод в виде numpy массива
def predict(self, image):
"""Predict the labels for a single screenshot
image -- The numpy array of the image to classify
"""
img = np.expand_dims(image, axis=0)
img = preprocess_input(img)
prediction = self.model.predict(img, batch_size=1)
Теперь я получаю изображение, которое это массив 1d numpy (23280,), когда я передаю это в модель прогнозирования, я получаю сообщение об ошибке, как показано ниже:
prediction = model.predict(np.asarray(bytearray(ss_read))) # np.asarray(bytearray(ss_read)) is 1d numpy array (23280,)
ValueError: Error when checking input: expected mobilenet_1.00_224_input to have 4 dimensions, but got array with shape (1, 23280)
так, как я могу изменить этот массив numpy и сделать это готово для предсказателя? Я думаю, что я могу сделать что-то вроде np.reshape(np.asarray(bytearray(ss_read)), (image_width, image_height, 3))
, но общий объем данных не совсем то же самое после преобразования в этом случае (224 * 224 * 3 = 150528> 23280). я должен сделать что-то вроде этого вместо np.reshape(np.asarray(bytearray(ss_read)), (image_width, -1, 3))
?