Я пытаюсь запустить линейное ядро SVM, используя сгенерированный набор данных. В моем наборе данных 5000 строк и 4 столбца:
CL_scaled.head()[screenshot of data frame][1]
Я разбил данные на 20% теста и 80% обучения:
train, test = train_test_split(CL_scaled, test_size=0.2)
и получил форму (4000,4)для поезда и (1000,4) для теста
Однако, когда я запускаю svm на данных обучения и тестирования, я получаю следующую ошибку:
svclassifier = SVC(kernel='linear', C = 5)
svclassifier.fit(train, test)
ValueError Traceback (most recent call last)
<ipython-input-81-4c4a7bdcbe85> in <module>
----> 1 svclassifier.fit(train, test)
~/anaconda3/lib/python3.7/site-packages/sklearn/svm/base.py in fit(self, X, y, sample_weight)
144 X, y = check_X_y(X, y, dtype=np.float64,
145 order='C', accept_sparse='csr',
--> 146 accept_large_sparse=False)
147 y = self._validate_targets(y)
148
~/anaconda3/lib/python3.7/site-packages/sklearn/utils/validation.py in check_X_y(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, warn_on_dtype, estimator)
722 dtype=None)
723 else:
--> 724 y = column_or_1d(y, warn=True)
725 _assert_all_finite(y)
726 if y_numeric and y.dtype.kind == 'O':
~/anaconda3/lib/python3.7/site-packages/sklearn/utils/validation.py in column_or_1d(y, warn)
758 return np.ravel(y)
759
--> 760 raise ValueError("bad input shape {0}".format(shape))
761
762
ValueError: bad input shape (1000, 4)
Может кто-нибудь, пожалуйста, дайте мнезнаете, что не так с моим кодом или данными? Заранее спасибо!
train.head()
0 1 2 3
2004 1.619999 1.049560 1.470708 -1.323666
1583 1.389370 -0.788002 -0.320337 -0.898712
1898 -1.436903 0.994719 0.326256 0.495565
892 1.419123 1.522091 1.378514 -1.731400
4619 0.063095 1.527875 -1.285816 -0.823347
test.head()
0 1 2 3
1118 -1.152435 -0.484851 -0.996602 1.617749
4347 -0.519430 -0.479388 1.483582 -0.413985
2220 -0.966766 -1.459475 -0.827581 0.849729
204 1.759567 -0.113363 -1.618555 -1.383653
3578 0.329069 1.151323 -0.652328 1.666561
print(test.shape)
print(train.shape)
(1000, 4)
(4000, 4)