Я пытаюсь GridsearchCV
, но я хотел бы иметь некоторые исключения в param grid
. Вот мой код поиска по сетке:
from sklearn.model_selection import GridSearchCV
from keras.wrappers.scikit_learn import KerasClassifier
def create_model(input_dim=25, activation='relu', units=100, optimizer = 'adam', init='he_normal', dropout_rate=0.33):
model = Sequential()
model.add(Dense(input_dim=input_dim,
units=units,
kernel_initializer=init,
activation=activation))
model.add(Dropout(dropout_rate))
model.add(Dense(1, kernel_initializer=init, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
return model
model = KerasClassifier(build_fn=create_model, epochs=10, batch_size=64, verbose=1)
#lr = [0.01, 0.001, 0.0001] # [x/100 for x in range(1, 10)] #learning rate for optimizer
units = [int(train_X.shape[1]/2), train_X.shape[1], train_X.shape[1]*2, train_X.shape[1]*3]
batch_size = [32, 64, 128, 256]
optimizer = ['SGD', 'RMSprop', 'Adagrad', 'Adadelta', 'Adam', 'Adamax', 'Nadam']
epochs = [50, 100, 200]
init = ['uniform', 'lecun_uniform', 'normal', 'zero', 'glorot_normal', 'glorot_uniform', 'he_normal', 'he_uniform']
activation = ['softmax', 'softplus', 'softsign', 'relu', 'tanh', 'sigmoid', 'hard_sigmoid', 'linear']
dropout_rate = [0.2, 0.3, 0.4, 0.5]
cv = [(slice(None), slice(None))]
param_grid = dict(units=units, batch_size=batch_size, optimizer=optimizer, epochs=epochs,
init=init, activation=activation, dropout_rate=dropout_rate)
grid = GridSearchCV(cv=cv, estimator=model, param_grid=param_grid, n_jobs=1)
grid_result = grid.fit(train_X, train_y, validation_data=(valid_X, valid_y))
когда я запускаю его код, хотя KerasClassifier
имеет параметр epochs=10
, сетка никогда не запускает 10 эпох, но работает на [50, 100, 200]
, который я указал в сетке. Это как epochs=10
было перезаписано.
Теперь я хочу использовать разные activation functions
в первом слое, но оставить Sigmoid
в выходном слое. Боюсь, что параметр activation='sigmoid'
будет перезаписан моим параметром activation = ['softmax', 'softplus', 'softsign', 'relu', 'tanh', 'sigmoid', 'hard_sigmoid', 'linear']
, который поступит из сетки?
Надеюсь, это имеет для вас больше смысла.