Я пытаюсь рассчитать ROC и AUC для модели SVM, которую я строю.Я следую коду из этого примера sklearn .Одним из требований является то, что выходные метки y
должны быть преобразованы в двоичную форму.Я делаю это с помощью создания MultiLabelBinarizer
и кодирую все метки, что прекрасно работает.Однако это создает ndarray (n_samples, n_features).Функция classifier.fit(X, y)
предполагает y.shape = (n_samples)
.Я хочу по существу «смять» столбцы y
вместе, чтобы y [0] [0] возвращал весь вектор объекта, V
, а не только первое значение V
.
Вот мой код:
enc = MultiLabelBinarizer()
print("Encoding data...")
# Fit the encoder onto all possible data values
print(pandas.DataFrame(enc.fit_transform(df["present"] + df["member"].apply(str).apply(lambda x: [x])),
columns=enc.classes_, index=df.index))
X, y = enc.transform(df["present"]), list(df["member"].apply(str))
print("Training svm...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5, random_state=0)
y_train = enc.transform([[x] for x in y_train]) # Strings to 1HotVectors
svc = svm.SVC(C=1.1, kernel="linear", probability=True, class_weight='balanced')
svc.fit(X_train, y_train) # y_train should be 1D but isn't
Исключение, которое я получаю:
Traceback (most recent call last):
File "C:/Users/SawyerPC/PycharmProjects/DiscordSocialGraph/encode_and_train.py", line 129, in <module>
enc, clf, split_data = encode_and_train(df)
File "C:/Users/SawyerPC/PycharmProjects/DiscordSocialGraph/encode_and_train.py", line 57, in encode_and_train
svc.fit(X_train, y_train) # TODO y_train needs to be flattened to (n_samples,)
File "C:\Users\SawyerPC\Anaconda3\lib\site-packages\sklearn\svm\base.py", line 149, in fit
X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
File "C:\Users\SawyerPC\Anaconda3\lib\site-packages\sklearn\utils\validation.py", line 547, in check_X_y
y = column_or_1d(y, warn=True)
File "C:\Users\SawyerPC\Anaconda3\lib\site-packages\sklearn\utils\validation.py", line 583, in column_or_1d
raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (5000, 10)