Почему CNN в Python работает намного хуже, чем в Matlab? - PullRequest
4 голосов
/ 04 марта 2020

Я обучил CNN в Matlab 2019b, который выполняет двоичную классификацию. Когда этот CNN был протестирован в наборе тестовых данных, он получал точность около 95%. Я использовал функцию exportONNXNetwork , чтобы я мог реализовать свой CNN в Tensorflow, Керас. Это код, который я использую для использования файла ONNX в кератах:

import onnx
from onnx_tf.backend import prepare
import numpy as np
from numpy import array
from IPython.display import display
from PIL import Image

onnx_model = onnx.load("model.onnx")
tf_rep = prepare(onnx_model)
img = Image.open("image.jpg").resize((224,224))
img = array(img).reshape(1,3,224,224)
img = img.astype(np.uint8)

classification = tf_rep.run(img)
print(classification)

Когда этот код python был протестирован на том же наборе данных , он классифицировал почти все как класс 0 с несколькими делами класса 1. Я не уверен, почему это происходит.

1 Ответ

2 голосов
/ 04 марта 2020

На первый взгляд, я думаю, что вам нужно переставить оси изображения, а не изменить его форму:

img = Image.open("image.jpg").resize((224,224))
img = array(img).transpose(2, 0, 1)
img = np.expand_dims(img, 0)

Изображение, полученное из PIL, имеет последний формат каналов, т. Е. Тензор формы (height, width, channels) в данном случае (224, 224, 3). Ваша модель ожидает ввод в формате первого канала, т. Е. Тензор формы (channels, height, width), в данном случае (3, 224, 224).

. Вам необходимо переместить последнюю ось вперед. Если вы используете изменение формы, NumPy будет проходить массив в порядке C (индекс последней оси изменяется быстрее всего), что означает, что ваше изображение будет зашифровано. Это легче понять на примере:

>>> img = np.arange(48).reshape(4, 4, 3)
>>> img[0, 0, :]
array([0, 1, 2])

Значения RGB пикселя (0, 0) равны (0, 1, 2). Если вы используете np.transpose(), это сохраняется:

>>> img.transpose(2, 0, 1)[:, 0, 0]
array([0, 1, 2])

Если вы используете изменение формы, ваше изображение будет зашифровано:

>>> img.reshape(3, 224, 224)[:, 0, 0]
array([0, 16, 32])
...