Ошибка при построении данных обучения и тестирования набора данных scikit-learn - PullRequest
0 голосов
/ 11 октября 2019

Я пытаюсь построить данные обучения и теста из набора данных scikit-learn.

import sys, os
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
plt.switch_backend('agg')

%matplotllib inline

diabetes = datasets.load_diabetes()
diabetes_X = diabetes.data[:, np.newaxis, 2]
diabetes_X_train = diabetes_X[:-20]
diabetes_X_test = diabetes_X[-20:]
diabetes_y_train = np.matrix(diabetes.target[:-20]).T
diabetes_y_test = np.matrix(diabetes.target[-20:]).T

plt.scatter(diabetes_X_train, diabetes_y_train,  color='black')
plt.scatter(diabetes_X_test, diabetes_y_test,  color='red')

, но у меня появляется следующая ошибка:

ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 422 and the array at index 1 has size 1

Я проверил формуМатрицы и тренировочные данные имеют (422,1) и тестовые данные (20,1). Что вызывает эту ошибку?

1 Ответ

1 голос
/ 11 октября 2019

plt.scatter ожидает построить два набора данных одинаковой формы друг против друга. Если они не 1D, они будут сплющены. Нет смысла выравнивать X в задаче машинного обучения.

Проверьте размеры X_train и y_train. Вы увидите, что они не совместимы. Это 2D-график, который вы делаете, вы можете наносить только один набор чисел против другого. X - это матрица: каждая строка - это набор чисел.

Итак, вы можете сделать это:

import numpy as np
import matplotlib.pyplot as plt

x, y = np.random.random((422, 1)), np.random.random((422, 1))
plt.scatter(x, y)

Но вы не можете сделать это:

X, y = np.random.random((422, 10)), np.random.random((422, 1))
plt.scatter(X, y)

По сути, это то, что вы пытаетесь сделать. (Между прочим, я не думаю, что вы хотите транспонировать y.)

Так что это должно работать для вас:

plt.scatter(diabetes_X_train[:, 0], diabetes_y_train)

Но это показывает только связь с одной особенностьюX.

Если вы просто пытаетесь изучить данные, я рекомендую проверить seaborn.pairplot. Это идеально подходит для такого рода вещей.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...