получить параметры от встроенного настраиваемого трансформатора pyspark - PullRequest
0 голосов
/ 06 марта 2019

Предположим, следующий пользовательский преобразователь Pyspark:

class CustomTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):

    def __init__(self, output_col):
        self.output_col = output_col
        self.feat_cols = None
        super(CustomTransformer, self).__init__()

    def _transform(self, df):

        self.feat_cols = get_match_columns(df, "ops")
        # Do something smart here with this feat_cols
        df = df.drop(*self.feat_cols)

        return df

, где feat_cols вычисляется и устанавливается в методе _transform(), а get_match_columns - это функция, которая возвращает имена столбцов, которые соответствуют некоторому шаблону.,Мне нужно получить доступ к этому параметру после преобразования конвейера, содержащего этот преобразователь, например:

pipeline = Pipeline(stages=[custom_transformer, assembler])
myPipe = pipeline.fit(data)
result = myPipe.transform(data)

с помощью какого-либо метода, например:

result.stages[0].getParam('feat_cols')

, но, очевидно, это не такРабота.Я пытался следовать этому упаковщику , кодируя этот геттер в моем преобразователе:

def getFeatCols(self):
        return self.getOrDefault(self.feat_cols)

, но я все еще не могу восстановить параметр (либо result.stages[0]._java_obj.getParam('feat_cols') работает).

Есть ли способ решить эту проблему в Pyspark?

1 Ответ

0 голосов
/ 07 марта 2019

Как отметил @ user10938362 в комментарии, необходимо использовать Param.В данном конкретном случае код, который мне подходит:

from pyspark.ml.param import Param

class CustomTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):

    def __init__(self, output_col):
        super(CustomTransformer, self).__init__()
        self.output_col = output_col
        self.feat_cols = Param(self, "feat_cols", "Feature columns")
        self._set(feat_cols=[]) # set or _set depends on the Spark version


    def _transform(self, df):
        self._set(feat_cols=get_match_columns(df, "ops"))
        # Do something smart here with this feat_cols
        df = df.drop(*self.getFeatCols())

        return df

    def getFeatCols(self):

        return self.getOrDefault("feat_cols")
...