Обработайте данные обучения / перекрестной проверки по-разному перед передачей в GridSearchCV от sklearn - PullRequest
0 голосов
/ 22 мая 2019

Я тренирую ванильную нейронную сеть в keras. Чтобы выбрать лучшие гиперпараметры, я хотел бы использовать scikit-learn API (документация здесь ). Однако, поскольку я заинтересован в прогнозировании вероятности для каждого ярлыка, я использую следующий нетрадиционный способ обучения модели:

  • Метки в моем наборе данных являются неотрицательными целыми числами. Для обучения модели я повторяю выборки на основе этих значений, чтобы получить правильную функцию потерь. Таким образом, образец с метками [1, 3, 0] заменяется четырьмя образцами [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0].
  • На тренировке модели с этими повторными выборками, чтобы получить вектор вероятности (который я ожидаю получить), я преобразовываю вышеупомянутые неотрицательные целые числа, найденные в метках набора данных, в вектор вероятности и использую categorical cross-entropy как функция потерь. Метки в образце в приведенном выше примере будут заменены на [0.25, 0.75, 0].

Я хотел бы реализовать настройку гиперпараметра в GridSearchCV программы scikit-learn.

Вот часть Keras, которая отлично работает:

keras_model = get_model()
# get_model() returns a compiled Keras model
train = repeat_samples(train_data)
test = to_prob_vec(test_data)
keras_model.fit(train[x_cols], train[y_cols])
keras_model.evaluate(test[x_cols], test[y_cols])

И часть склеарна, которая нуждается в соответствующей модификации:

sk_model = KerasRegressor(build_fn=get_model())
param_grid = dict(epochs=[10, 20, 30])
grid = GridSearchCV(estimator=sk_model, param_grid=param_grid)
grid_result = grid.fit(<what here>)
# Following is incorrect because during cross-validation 
# it has dummy-coded variables instead of probability vectors:
# grid_result = grid.fit(train[x_cols], train[y_cols])

Я подозреваю, что соответствующее изменение параметра cv в GridSearchCV ( документация ) поможет мне достичь этого. Я не могу понять, как.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...