Я не уверен, что правильно вас понял, но предлагаю следующее решение. KerasClassifier
наследуется от BaseWrapper
, который имеет следующую __init__
подпись:
def __init__(self, build_fn=None, **sk_params):
self.build_fn = build_fn
self.sk_params = sk_params
self.check_params(sk_params)
хорошо, что такое build_fn и sk_params?
build_fn
должен создать, скомпилировать и вернуть модель Keras, которая
затем будет использоваться для подгонки / прогнозирования. Один из следующих
три значения могут быть переданы в build_fn
:
1. Функция
2. Экземпляр класса, который реализует метод __call__
3. Нет. Это означает, что вы реализуете класс, который наследуется от
KerasClassifier
или KerasRegressor
. __call__
метод
Настоящий класс будет обрабатываться по умолчанию build_fn
.
...
sk_params
принимает как параметры модели, так и параметры подгонки. Правовая модель
параметры являются аргументами build_fn
. Обратите внимание, что, как и все другие
Оценщики в Scikit-Learn, build_fn
должны предоставить значения по умолчанию для
его аргументы, чтобы вы могли создать оценщик, не передавая
значения до sk_params
.
...
некоторые символы не указаны
Вы можете прочитать полный комментарий на это и это ссылки.
Поскольку build_fn
ожидает функцию, которая возвращает скомпилированную модель keras (независимо от того, что это - Sequential
или просто Model
) - вы можете передать в качестве значения функцию, которая возвращает загруженную модель.
Редактировать Также вам следует позвонить fit
с некоторыми параметрами, чтобы восстановить модель с использованием этого подхода.
загрузить модель как build_fn
Метод
fit
вызывает build_fn
, следовательно, каждый раз, когда вы пытаетесь обучить такой классификатор, вы загружаете, а затем устанавливаете загруженный классификатор.
Например:
from keras.models import load_model # or another method - but this one is simpliest
from keras.wrappers.scikit_learn import KerasClassifier
def load_model(*args, **kwargs):
"""probably this function expects sk_params, so you can use it in theory"""
path="my_model.hd5"
model = load_model(path)
return model
keras_classifier = KerasClassifier(load_model, sk_params) # use your sk_params
keras_classifier.fit(X_tr, y_tr) # I use slice (1, input_shape) to train
- это будет работать, так как загруженная модель практически обучена и скомпилирована. Но это дает небольшой сдвиг для вашей модели, даже если вы позвоните ей с партией размера 1 и на 1 эпоху.
загрузка через закрытие build_fn
Также вы можете сначала загрузить модель (если вы хотите легко указать путь, а путь к жесткому коду неприемлемо), а затем вернуть функцию, которая "build_fn - Принимаемо" :
def load_model_return_build_fn(path):
model = load_model(path)
def build_fn(*args, **kwars):
"""probably this function expects sk_params"""
return model # defined above
return build_fn
build_fn = load_model_return_build_fn("model.hd5")
keras_classifier = KerasClassifier(build_fn, sk_params) # use your sk_params
keras_classifier.fit(X_tr, y_tr) # I use slice (1, input_shape) to train
назначить модель для ее атрибута
Если вы планируете просто загрузить и использовать предварительно обученную модель, вы можете использовать любую для ее загрузки, присвоить атрибуту model
и не вызывать fit
.
build_fn = load_model_return_build_fn("model.hd5")
# or the function which realy builds and fits a model
keras_classifier = KerasClassifier(build_fn, sk_params) # use your sk_params
keras_classifier.model = model # assign model here, don't call fit
- в этом случае вы явно устанавливаете модель для ее атрибута. Обратите внимание, что build_fn должен быть правильным build_fn - иначе он не пройдет тест self.check_params(sk_params)
.
Наследование от KerasClassifier (не так просто, как я думал)
В конце концов, лучшее из известных мне решений - наследовать от KerasClassifier
и добавить метод load
и / или from_file
.
class KerasClassifierLoadable(KerasClassifier):
@classmethod
def from_file(cls, path, *args, **kwargs):
keras_classifier = cls(*args, **kwargs)
keras_classifier.model = load_model(path)
outp_shape = keras_classifier.model.layers[-1].output_shape[-1]
if outp_shape > 1:
keras_classifier.classes_ = np.arange(outp_shape, dtype='int32')
else:
raise ValueError("Inconsistent output shape: outp_shape={}".format(outp_shape))
keras_classifier.n_classes_ = len(keras_classifier.classes_)
return keras_classifier
def load(self, path):
self.model = load_model(path)
outp_shape = keras_classifier.model.layers[-1].output_shape[-1]
if outp_shape > 1:
keras_classifier.classes_ = np.arange(outp_shape, dtype='int32')
else:
raise ValueError("Inconsistent output shape: outp_shape={}".format(outp_shape))
self.n_classes_ = len(self.classes_)
здесь мы должны установить self.classes_ на правильные метки классов - но я использую только целочисленные значения из `range (0, n_classes).
Использование (build_fn
может быть любым подходящим build_fn
):
keras_classifier = KerasClassifierLoadable.from_file("model.hd5", build_fn=build_fn)
keras_classifier = KerasClassifierLoadable(build_fn=build_fn)
keras_classifier.load("model.hd5")