Устойчивость Sklearn Pipeline с пользовательским классом не работает - PullRequest
1 голос
/ 11 апреля 2019

Я использую конвейер sklearn в своем коде и сохраняю объект конвейера для развертывания в другом окружении. У меня есть один пользовательский класс для отбрасывания функций. Я успешно сохраняю модель, но когда я использую объект конвейера в другой среде, которая имеет ту же версию sklearn, он выдает ошибку. Конвейер работает нормально, когда я не включил свой пользовательский класс DropFeatures . Ниже приведен код

from sklearn import svm
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import chi2
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.externals import joblib
# Load the Iris dataset
df = pd.read_csv('Iris.csv')
label = 'Species'
labels = df[label]
df.drop(['Species'],axis=1,inplace=True)
# Set up a pipeline with a feature selection preprocessor that
# selects the top 2 features to use.
# The pipeline then uses a RandomForestClassifier to train the model.
class DropFeatures(BaseEstimator, TransformerMixin):

    def __init__(self, features_to_drop=None):
        self.features = features_to_drop

    def fit(self, X, y=None):
        return self

    def transform(self, X):
    # encode labels
        if len(self.features) != 0:
            X = X.copy()
            X = X.drop(self.features, axis=1)
            return X
        return X


pipeline = Pipeline([
      ('drop_features', DropFeatures(['Id'])),
      ('feature_selection', SelectKBest(chi2, k=1)),
      ('classification', RandomForestClassifier())
    ])


pipeline.fit(df, labels)
print(pipeline.predict(query))
# Export the classifier to a file
joblib.dump(pipeline, 'model.joblib')

Когда я использую model.joblib в другой среде, я получаю сообщение об ошибке. Ниже приведен код для загрузки модели и ошибка на изображении

from sklearn.externals import joblib
model = joblib.load('model1.joblib')
print(model)

Ошибка трассировки стека: Error stack trace image

...