Я пытаюсь классифицировать изображения как «безопасные» или «небезопасные» для этого, используя также обученную модель. Функция загрузки изображений в массивы numpy для передачи в model.predict
Входные данные:
1) image_paths: список путей к изображениям для загрузки
2) image_size: size в какие изображения следует изменить размер
Вывод:
1) loaded_images: загруженные изображения, на которых модель keras может выполнять прогнозы
2) loaded_image_indexes: пути изображений, которые выполняет функция способен обрабатывать
Я постоянно получаю сообщение об ошибке:
ValueError: Tensor Tensor("predictions/Softmax:0", shape=(?, 2), dtype=float32) is not an element of this graph.
Я тоже прикрепляю свой код:
import keras
from keras import backend as K
import numpy as np
import tensorflow as tf
graph = tf.get_default_graph()
class NudityFilter:
def __init__(self):
self.model_path = "./classifier_model"
def load_images(self, image_paths, image_size):
K.clear_session()
loaded_images = []
loaded_image_paths = []
for i, img_path in enumerate(image_paths):
try:
image = keras.preprocessing.image.load_img(img_path, target_size=image_size)
image = keras.preprocessing.image.img_to_array(image)
image /= 255
loaded_images.append(image)
loaded_image_paths.append(img_path)
except Exception as ex:
utils.logger.exception("__ERROR__ while loading image(s) " + str(ex))
return np.asarray(loaded_images), loaded_image_paths
def classify_obsence(self, image_paths=[]):
'''
inputs:
image_paths: list of image paths or can be a string too (for single image)
'''
try:
#Before prediction
K.clear_session()
batch_size = 32
image_size = (256, 256)
categories = ["unsafe", "safe"]
if isinstance(image_paths, str):
image_paths = list(image_paths)
loaded_images, loaded_image_paths = self.load_images(image_paths, image_size)
if not loaded_image_paths:
return {}
nsfw_model = keras.models.load_model(self.model_path)
global graph
with graph.as_default():
model_preds = nsfw_model.predict(loaded_images, batch_size=batch_size)
Здесь я получаю ValueError, как только вызывается функция nsfw_model.predict()
. Что я должен изменить / сделать, чтобы не получить эту ошибку. Когда я запускаю тот же код в GoogleColab / JupyterNotebook, он работает без ошибок. Почему это происходит?