Я пытаюсь отобразить 9 изображений из моего набора проверки вместе с классом, который предсказал моя модель, но я получаю ошибку из-за элемента изменения формы plt.imshow()
. Количество пикселей и каналов для моих изображений (128, 128, 3)
(RGB). Я попытался изменить размер формы на (128, 128, 1)
и (128, 128, 3)
и (1, 128, 128)
, и ни один из этих вариантов не работает. Как узнать, какими должны быть эти числа, чтобы plt.imshow () работал успешно? Я знаю, что есть связанный вопрос StackOverflow, но ответы на эти посты мне не помогли.
target_size=(128,128) # target pixel size of each image
batch_size = 20 # the number of images to load per iteration
# configure a data generator which will rescale the images and create a training
# and test split where the test set is 10% of the data
data_gen_3 = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, validation_split=0.1)
val_img = data_gen_3.flow_from_directory(data_path,
subset='validation',
color_mode='rgb',
target_size=target_size,
batch_size=batch_size,
class_mode='categorical')
# get a sample of 20 (batch_size) validation images
sample_imgs_val, sample_labels_val = next(val_img)
# predict the class for the sample val images using the final model called "convnet"
X_pred_class = convnet.predict(sample_imgs_val)
# get the most likely class number for the prediction for each image
predicted_classes = np.argmax(X_pred_class, axis=1)
# display 9 of the images along with the predicted class
for img in range(9):
plt.subplot(3, 3, img + 1, frameon=False)
plt.imshow( np.reshape(sample_imgs_val[img],(128,128)) )
plt.title(predicted_classes[img])
plt.show()