plt: IndexError: слишком много индексов для массива - PullRequest
0 голосов
/ 24 февраля 2020

Я использую классификатор KNN для этого набора данных:

enter image description here

Я пытаюсь получить график, но я получил следующую ошибку:

Traceback (most recent call last):
  File "/home/ashutosh/Machine Learning A-Z Template Folder/Part 3 - Classification/Section 20 - Random Forest Classification/P14-Random-Forest-Classification/Random_Forest_Classification/rft.py", line 38, in <module>
    plt.scatter(x_set[y_set == j, 0], x_set[y_set == j, 1], c=ListedColormap(('red', 'green'))(i), label = j)
IndexError: too many indices for array

при выполнении следующего кода:

from matplotlib.colors import ListedColormap
x_set, y_set = x_train, y_train
x1, x2 = np.meshgrid(np.arange(start=x_set[:, 0].min() - 1, stop=x_set[:, 0].max() + 1, step = 0.01),
                     np.arange(start=x_set[:, 1].min() - 1, stop=x_set[:, 1].max() + 1, step = 0.01))
plt.contour(x1, x2, classifier.predict(np.array([x1.ravel(), x2.ravel()]).T).reshape(x1.shape),
            alpha = 0.75, cmap = ListedColormap(('red', 'green')))
plt.xlim(x1.min(), x1.max())
plt.ylim(x2.min(), x2.max())
for i, j in enumerate(np.unique(y_set)):
    plt.scatter(x_set[y_set == j, 0], x_set[y_set == j, 1], c=ListedColormap(('red', 'green'))(i), label=j)
plt.show()
...