У меня есть приложение Pyspark, которое в основном загружает файлы изображений где-то s3 и извлекает функции из этих файлов изображений с помощью керасов.
Вот весь поток: -
1. Download images from s3 using.
s3_files_rdd = sc.binaryFiles(s3_path) ## [('s3n://..',bytearray)]
2. Then convert the above byte inside the rdd to image object.
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from io import BytesIO
def convert_binary_to_image_obj(obj):
img = mpimg.imread(BytesIO(obj), 'jpg')
return img
images_rdd = s3_files_rdd.map(lambda x: (x[0], convert_binary_to_image_obj(x[1])))
3. Now pass the images_rdd to another function to extract features using keras vgg16 model.
def initVGG16():
model = VGG16(weights='imagenet', include_top=True)
return Model(inputs=model.input, outputs=model.get_layer("fc2").output)
def extract_features(img):
img_data = image.img_to_array(img)
img_data = np.expand_dims(img_data, axis=0)
img_data = preprocess_input(img_data)
vgg16_feature = initVGG16().predict(img_data)[0]
return vgg16_feature
features_rdd = images_rdd.map(lambda x: (x[0], extract_features(x[1])))
Но когда я пытаюсь использовать приложение, выдается следующее сообщение об ошибке: -
ValueError: Error when checking input: expected input_1 to have shape (224, 224, 3) but got array with shape (300, 200, 3)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:330)
at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:470)
at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:453)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:284)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$class.foreach(Iterator.scala:893)
Я знаю, что здесь ошибка в функции extract_features, которая ожидает, что размер изображения будет 224,224,3, чего сейчас нет. Потому что я не сохраняю образ на свой локальный диск. Я напрямую конвертирую, используя matplotlib lib, в объект изображения после загрузки с s3.
Как решить эту проблему? В основном я хочу загрузить изображение из s3, а затем изменить его размер в памяти, как работает image.load_img(image_path, target_size=(224, 224))
, а затем передать этот объект изображения в мою функцию extract_features.