Keras-TF model.predict () дал мне неверный результат - PullRequest
0 голосов
/ 03 октября 2019

Я только что обучил свою модель CNN, у которой есть детали слоя ниже:

model = Sequential()
model.add(Conv2D(100, (3, 3), input_shape=(100, 100, 1), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(120, (5, 5), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(140, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(rate=0.5))
model.add(Flatten())
model.add(Dense(200, activation='relu'))
model.add(Dropout(rate=0.5))
model.add(Dense(100, activation='relu'))
model.add(Dropout(rate=0.5))
model.add(Dense(num_classes, activation='softmax'))
model.compile(Adam(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

Моя модель имеет точность приблизительно 95.xxx% при проверке / проверке результатов кера.

И когда я импортировал файл h5, чтобы проверить предсказание изображения, я сделал:

def grayscale(img):
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  return img

def equalize(img):
  img = cv2.equalizeHist(img)
  return img

def preprocessing(img):
  gray = grayscale(img)
  eq = equalize(gray)
  return eq / 255

image = cv2.imread("sample.jpg")
print("Loading model...")
model = load_model("model.h5")
classes = pd.read_csv('dataset.csv', header=0, usecols=['BananaOrNot']).values

image = preprocessing(image)
image = image.reshape(100, 100, 1)
image = np.expand_dims(image, axis=0)
y_prob = model.predict(image)
class_idx = y_prob.argmax(axis=-1)[0]
print(classes[class_idx][0]) # it produced wrong result (should be 'not banana', got 'banana')

Я проверил с образцом пустое белое изображение, и предсказание дало мне Банан результат, хотя он должен быть Не банан точно. Как это может быть неправильно в этом очень простом тестовом примере? Или что-то не так с моим model.predict(img) входным изображением?

1 Ответ

0 голосов
/ 03 октября 2019

Поскольку в вашем наборе данных нет класса с именем Not banana , он дает только те классы, которые вы обучили в своей сети.

Для того, чтобы такое поведение было в вашей моделиВы должны установить класс , например: Фон на вашей модели и обучить его определенному количеству Фон изображений с надписью. В противном случае это даст результат, который наиболее вероятен в вашем выходном векторе , т.е. наиболее вероятный класс, даже если это пустое изображение.

ИЛИ

Вы можете обучить классификацииМодель наряду с локализацией и классификацией встречается только в локализованных объектах, и ваша проблема может быть исправлена.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...