Как использовать KerasClassifier с OneVsRestClassifier? - PullRequest
0 голосов
/ 05 февраля 2019

Я пытаюсь использовать KerasClassifier с OneVsRestClassifier, чтобы разбить проблему классификации нескольких классов на различные подзадачи двоичной классификации.

Это код, который я использую:

import keras as k
import scipy as sp
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import RandomizedSearchCV

def build_keras_model(loss = 'binary_crossentropy', metrics = ['accuracy'], optimiser = 'adam', 
                  learning_rate = 0.001, n_neurons = 30, n_layers = 1, n_classes = 2,
                  l1_reg = 0.001, l2_reg = 0.001, batch_norm = False, dropout = None, 
                  input_shape = (8,)):

    model = k.models.Sequential()

    model.add(k.layers.Dense(n_neurons, 
                             input_shape = input_shape,
                             kernel_regularizer = k.regularizers.l1_l2(l1 = l1_reg, l2 = l2_reg),
                             activation = 'relu'))
    if batch_norm is True:
        model.add(k.layers.BatchNormalization())
    if dropout is not None:
        model.add(k.layers.Dropout(dropout))

    i = 1   
    while i < n_layers:
        model.add(k.layers.Dense(n_neurons,
                                 kernel_regularizer = k.regularizers.l1_l2(l1 = l1_reg, l2 = l2_reg),
                                 activation = 'relu'))
        if batch_norm is True:
            model.add(k.layers.BatchNormalization())
        if dropout is not None:
            model.add(k.layers.Dropout(dropout))
        i += 1
    del i

    model.add(k.layers.Dense(1, activation = 'sigmoid'))

    if optimiser == 'adam':
        koptimiser = k.optimizers.Adam(lr = learning_rate)
    elif optimiser == 'adamax':
        koptimiser = k.optimizers.Adamax(lr = learning_rate)
    elif optimiser == 'nadam':
        koptimiser = k.optimizers.Nadam(lr = learning_rate)
    else:
        print('Unknown optimiser type')

    model.compile(optimizer = koptimiser, loss = loss, metrics = metrics)

    model.summary()

    return model

#==============================================================

parameters =    {
            'estimator__optimiser': ['adam', 'adamax', 'nadam'],
            'estimator__learning_rate': sp.stats.uniform(0.0005, 0.0015),
            'estimator__epochs': sp.stats.randint(10, 50),
            'estimator__n_neurons': sp.stats.randint(20, 61),
            'estimator__n_layers': sp.stats.randint(1, 3),
            'estimator__n_classes': [2],
            'estimator__batch_size': sp.stats.randint(1, 11),
            'estimator__l1_reg': sp.stats.reciprocal(1e-3, 1e1),
            'estimator__l2_reg': sp.stats.reciprocal(1e-3, 1e1),
            'estimator__dropout': [None],
            'estimator__metrics': [['accuracy']],
            'estimator__loss': ['binary_crossentropy'],
            'estimator__input_shape': [(X_train.shape[1],)]
            }


keras_model = OneVsRestClassifier(KerasClassifier(build_fn = build_keras_model, 
                                                  class_weight='balanced',
                                                  verbose = 1))

clf = RandomizedSearchCV(keras_model, 
                         parameters, 
                         n_iter = 1, 
                         scoring = 'balanced_accuracy', 
                         n_jobs = 1, 
                         cv = 5, 
                         random_state = 100)


clf.fit(X_train, y_train)

model = clf.best_estimator_

Этот код не работает со следующей ошибкой:

/ opt / conda / lib / python3.6 / site-packages / sklearn / base.py в клоне (оценщик, безопасный) 71поднять RuntimeError ('Невозможно клонировать объект% s, поскольку конструктор' 72 'либо не устанавливает, либо изменяет параметр% s'% ---> 73 (оценщик, имя)) 74 return new_object 75

RuntimeError:Не удается клонировать объект, так как конструктор не устанавливает или не изменяет параметр l1_reg

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...