Невозможно использовать стандартные возможности sklearn
, но вы можете создать класс, выполняющий то, что вы хотите легко (и я бы согласился, что он более читабелен).
Я бы лично пошел на что-то подобное:
class SlicingClassifier:
"""Create classifiers of slice data.
Parameters
----------
classifier : sklearn-compatible classifier
Classifier with sklearn-compatible interface. Needs fit method.
slices : array-like[bool] or generator
List of Boolean numpy arrays able to slice features and targets.
Can be a generator returning slices.
"""
def __init__(self, classifier, slices):
self.classifier_ = classifier
self.slices_ = slices
self.models_ = []
def fit(self, X, y):
for index_slice in self.slices_:
self.models_.append(self.classifier_.fit(X[index_slice], y[index_slice]))
return self
# You may want to make this a list, but it's more memory-efficient as gen
def predict(self, X):
for model in self.models_:
yield model.predict(X)
При желании вы можете легко расширить этот подход, используя несколько классификаторов, другой метод predict
, fit_transform
, что делает API-интерфейс совместимым с sklearn
и т. Д.
Хорошим дополнением (с точки зрения памяти) может быть генератор fit_transform
подобная функция, если вы заботитесь только о предсказаниях для каждого подмножества данных:
def fit_transform_generator(self, X, y):
for index_slice in self.slices_:
yield self.classifier_.fit_transform(X[index_slice], y[index_slice])
И примерные вызовы будут идти по этим линиям и спасут вас отнекрасивые временные создания нарезанных массивов.массивы, которые нарезают исходную X
).
Срезы как генератор
Вместо передачи срезов как List
вы также можете использовать объект генератораl (для многих индексов срезов вы должны использовать этот подход).
Пример, приведенный для устранения путаницы:
def slices_generator(X, stop, start=0, step=1):
for i in range(start, stop, step):
yield X < i
classifier = SlicingClassifier(
GradientBoostingClassifier(
n_estimators=100, learning_rate=1.0, max_depth=5, random_state=1
),
slices=slices_generator(X, 1000),
)