Отличная идея: просто попробуйте и посмотрите, как это происходит. Это очень зависит от вашего набора данных и выбора модели, как вы уже догадались. Трудно предсказать, как будет вести себя добавление этого типа функции, как и любая другая функция разработки. Но будьте осторожны, в некоторых случаях это даже не улучшает вашу производительность. См. Тест ниже, где производительность фактически снижается, с набором данных Iris:
import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn import metrics
# load data
iris = load_iris()
X = iris.data[:, :3] # only keep three out of the four available features to make it more challenging
y = iris.target
# split train / test
indices = np.random.permutation(len(X))
N_test = 30
X_train, y_train = X[indices[:-N_test]], y[indices[:-N_test]]
X_test, y_test = X[indices[N_test:]], y[indices[N_test:]]
# compute a clustering method (here KMeans) based on available features in X_train
kmeans = KMeans(n_clusters=3, random_state=0).fit(X_train)
new_clustering_feature_train = kmeans.predict(X_train)
new_clustering_feature_test = kmeans.predict(X_test)
# create a new input train/test X with this feature added
X_train_with_clustering_feature = np.column_stack([X_train, new_clustering_feature_train])
X_test_with_clustering_feature = np.column_stack([X_test, new_clustering_feature_test])
Теперь давайте сравним две модели, которые выучились либо только на X_train
, либо на X_train_with_clustering_feature
:
model1 = SVC(kernel='rbf', gamma=0.7, C=1.0).fit(X_train, y_train)
print(metrics.classification_report(model1.predict(X_test), y_test))
precision recall f1-score support
0 1.00 1.00 1.00 45
1 0.95 0.97 0.96 38
2 0.97 0.95 0.96 37
accuracy 0.97 120
macro avg 0.97 0.97 0.97 120
weighted avg 0.98 0.97 0.97 120
И другая модель:
model2 = SVC(kernel='rbf', gamma=0.7, C=1.0).fit(X_train_with_clustering_feature, y_train)
print(metrics.classification_report(model2.predict(X_test_with_clustering_feature), y_test))
0 1.00 1.00 1.00 45
1 0.87 0.97 0.92 35
2 0.97 0.88 0.92 40
accuracy 0.95 120
macro avg 0.95 0.95 0.95 120
weighted avg 0.95 0.95 0.95 120