Ошибка при попытке классифицировать 3D-изображения с использованием Наивного Байеса - PullRequest
0 голосов
/ 27 апреля 2020

Я создал алгоритм сверточных нейронных сетей для классификации изображений, и теперь я хочу сделать алгоритм Наивного Байеса для сравнения. Мои изображения 3D, и я думаю, что это причина ошибки, которую я получаю.

Ошибка:

raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (1776, 3)

Мой код:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
import numpy as np

much_data = np.load('muchdata-50-50-30-normalizado.npy', allow_pickle=True)
X = [data[0] for data in much_data]
y = [data[1] for data in much_data]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
gnb = GaussianNB()
y_pred = gnb.fit(X_train, y_train).predict(X_test)
print("Number of mislabeled points out of a total %d points : %d" % (X_test.shape[0], (y_test != y_pred).sum()))

Мой X [0] имеет следующий формат:

  [[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]
  ...
  [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]]

И мой y [0]:

[0 1 0]

Если кто-то может помочь мне понять, что я делаю неправильно, это будет будь очень полезным!

Большое спасибо!

1 Ответ

1 голос
/ 27 апреля 2020

Когда вы смотрите на y[0], кажется, что у вас есть 3 класса в формате быстрого кодирования. Алгоритмы машинного обучения sklearn в целом не принимают целевые значения в формате с горячим кодированием. Кроме того, вход (X) для модели должен иметь форму (no_samples, no_features). Следовательно, вы должны сгладить трехмерные изображения.

  1. Избавьтесь от одноразовых кодировок в цели (y) и получите одномерный массив в формате (no_samples,). Вы можете достичь этого, определив 3 класса как 1, 2, 3.
  2. Свести изображения. Вы можете сделать это с X = [data[0].flatten() for data in much_data]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...