Я хочу попробовать GridsearchCV на моей модели, мой импорт:
from keras import models
from keras import layers
from keras import regularizers
from sklearn.model_selection import GridSearchCV
from keras.wrappers.scikit_learn import KerasClassifier
мой код:
def build_model(X_train = X_train,neurons=4,optimizer='Adam'):
model = models.Sequential()
model.add(layers.Dense(X_train.shape[1], kernel_regularizer=regularizers.l2(0.001),
activation='relu', input_shape=(X_train.shape[1],)))
model.add(layers.BatchNormalization())
model.add(layers.Dense(neurons, kernel_regularizer=regularizers.l2(0.001), activation='relu'))
model.add(layers.BatchNormalization())
model.add(layers.Dense(1, activation='sigmoid'))
model.compile(optimizer=optimizer,loss='binary_crossentropy',metrics=['accuracy'])
return build_model
model = KerasClassifier(build_fn=build_model, verbose=1)
# define the grid search parameters
batch_size = [16, 32, 64]
epochs = [50, 100]
param_grid = dict(batch_size=batch_size, epochs=epochs)
# search the grid
grid = GridSearchCV(estimator=model,
param_grid=param_grid,
cv=10,
verbose=2)
grid_result = grid.fit(X_train, y_train)
, но я получаю ошибку, как показано ниже:
AttributeError Traceback (most recent call last)
<ipython-input-93-2eb813d3aab7> in <module>
12 verbose=2) # include n_jobs=-1 if you are using CPU
13
---> 14 grid_result = grid.fit(X_train, y_train)
15
16 print(model)
/anaconda3/envs/lance/lib/python3.7/site-packages/keras/wrappers/scikit_learn.py in fit(self, x, y, sample_weight, **kwargs)
208 if sample_weight is not None:
209 kwargs['sample_weight'] = sample_weight
--> 210 return super(KerasClassifier, self).fit(x, y, **kwargs)
211
212 def predict(self, x, **kwargs):
/anaconda3/envs/lance/lib/python3.7/site-packages/keras/wrappers/scikit_learn.py in fit(self, x, y, **kwargs)
141 self.model = self.build_fn(**self.filter_sk_params(self.build_fn))
142
--> 143 loss_name = self.model.loss
144 if hasattr(loss_name, '__name__'):
145 loss_name = loss_name.__name__
AttributeError: 'function' object has no attribute 'loss'
Я не могу понять, в чем ошибка, и я уверен, что обработка данных правильная, потому что она идет хорошо без поиска по сетке, я сделал что-то не так?