Я пытаюсь найти гиперпараметры для моей модели Keras, которая содержит пользовательскую функцию потерь, с помощью gridsearchcv .
Когда я запускаю это вручную (за пределами gridsearch), оно работает.Однако, если я запускаю gridsearchcv на нем и выполняю это параллельно, кажется, что существует некоторая проблема с сериализацией, которую выполняет бэкэнд Koby, при установке параметра n_jobs=-1
.
Вот как я настраиваюМоя пользовательская функция потерь:
from keras import backend as K
def my_loss(y_true, y_pred):
# perform some calculations using K
# ...
return 1 - K.mean(score)
Это моя модельная функция, которая берет некоторые параметры из поиска по сетке:
from keras.models import Model
from keras.layers import Dense, Dropout, Input
def create_model(hidden_layers, activation, last_activation, optimizer, drop_rate, initializer, units, input_dim):
# some layer definitions of inputs and outputs
model = Model(inputs, outputs)
model.compile(loss=my_loss, optimizer=optimizer, metrics=['binary_accuracy'])
return model
Наконец, как я собрал ее:
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from keras.wrappers.scikit_learn import KerasClassifier
skfs = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
parameters = {'hidden_layers': [...],
'activation':[...],
'last_activation': [...],
'optimizer':[...],
'drop_rate':[...],
'initializer':[...],
'units':[...],
'input_dim': [...]
}
model = KerasClassifier(build_fn=create_model, epochs=100, batch_size=256, verbose=0)
classifier = GridSearchCV(estimator=model, param_grid=parameters, cv=skfs, scoring=score, verbose=1, n_jobs=-1, return_train_score=True)
results = classifier.fit(X, y, callbacks=early_stopping)
Когда я выполняю приведенный выше код, я получаю исключение о том, что функция my_loss
неизвестна.Однако, если я изменю его на (см. n_jobs
параметр):
classifier = GridSearchCV(estimator=model, param_grid=parameters, cv=skfs, scoring=score, verbose=1, n_jobs=1, return_train_score=True)
, это сработает.Я считаю, что должен каким-то образом пропагандировать свою собственную функцию.Как я могу это сделать?