Я работаю над проблемой классификации, когда мне нужно предсказать класс текстовых данных. Мне нужно настроить гиперпараметры для моей модели классификации, для которой я собираюсь использовать GridSearchCV
. Мне также нужно сделать StratifiedKFold
, потому что мои данные несбалансированы. Я знаю, что GridSearchCV
внутренне использует StratifiedKFold
, если у нас есть мультиклассовая классификация.
Я прочитал здесь , что в случае TfidfVectorizer
мы применяем fit_transform
к обучающим данным и преобразуем только в тестовые данные.
Это то, что я сделал ниже, используя StratifiedKFold
.
skf = StratifiedKFold(n_splits=5, random_state=5)
for train_index, test_index in skf.split(X, y):
iteration = iteration+1
print(f"Iteration number {iteration}")
X_train, y_train = X.iloc[train_index], y.iloc[train_index]
X_test, y_test = X.iloc[test_index], y.iloc[test_index]
train_tfid = tfidf_vectorizer.fit_transform(X_train.values.astype('U'))
test_tfid = tfidf_vectorizer.transform(X_test.values.astype('U'))
svc_model = linear_model.SGDClassifier()
svc_model.fit(train_tfid, y_train.values.ravel())
Точность / f1, которую я получаю, не очень хорошая, поэтому подумал о настройке гиперпараметров с помощью GridSearchCV. В GridSearchCV мы делаем
c_space = np.logspace(-5, 8, 15)
param_grid = {'C': c_space}
# Instantiating logistic regression classifier
logreg = LogisticRegression()
# Instantiating the GridSearchCV object
logreg_cv = GridSearchCV(logreg, param_grid, cv = 5)
logreg_cv.fit(X, y)
По моему мнению, logreg_cv.fit(X, y)
будет внутренне разделить X на X_train
, X_test
k раз, а затем будет делать прогнозы, чтобы дать нам лучшую оценку.
В моем случае каким должен быть X? Если это X, который сгенерирован после fit_transform
, то внутри, когда X разделен на обучение и тестирование, тестовые данные прошли fit_transform
, но в идеале они должны преобразовываться только в go.
Меня беспокоит, что в моем случае внутри GridSearchCV
как я мог бы контролировать, что fit_transform
применяется только для данных обучения, а преобразование применяется к тестовым данным (данные проверки).
, потому что, если он внутренне применяет fit_transform к все данные, то это не лучшая практика.