Pyspark UDF для сравнения разреженных векторов - PullRequest
0 голосов
/ 12 марта 2019

Я пытаюсь написать UDF для pyspark, который будет сравнивать два Sparse Vector для меня.Я хотел бы написать следующее:

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, FloatType

def compare(req_values, values):
    return [req for req in req_values.indices if req not in values.indices]

compare_udf = udf(compare, ArrayType(IntegerType()))

display(data.limit(5).select('*', compare_udf('req_values', 'values').alias('missing')))

Однако, когда я запускаю этот код, я получаю следующее сообщение об ошибке:

SparkException: Job aborted due to stage failure: Task 0 in stage 129.0 failed 4 times, most recent failure: Lost task 0.3 in stage 129.0 (TID 1256, 10.139.64.15, executor 2): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype)

Я столкнулся с подобными проблемами, с которыми до этогоотносятся к типу dataframe, который не может справиться с несколькими типами данных.Раньше мне удавалось решить эти проблемы путем приведения массива numpy в список перед его возвратом, но в этом случае кажется, что я не могу даже извлечь данные из SparseVector, например, даже следующее не работает:

def compare(req_values, values):
    return req_values.indices[0]   

compare_udf = udf(compare, IntegerType())

Мне удалось обойти проблемы, используя RDD, но я все еще нахожу это разочаровывающим ограничением с UDF pyspark.Любой совет или помощь приветствуются!

1 Ответ

0 голосов
/ 12 марта 2019

Кажется, я решил эту проблему сам - проблема сводится к тому, что составляющие компоненты Sparse Vector из mllib являются типами numpy, которые сами по себе не поддерживаются pispark DataFrame. Работает следующий настроенный код:

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, FloatType

def compare(req_values, values):
    return [int(req) for req in req_values.indices if req not in values.indices]

compare_udf = udf(compare, ArrayType(IntegerType()))

display(data.limit(5).select('*', compare_udf('req_values', 'values').alias('missing')))

Хотя это работает, мне кажется несколько нелогичным, что pyspark DataFrame будет поддерживать сконструированный тип данных (SparseVector), но не сам по себе составные части (целые числа), и не будет предоставлять более информативное сообщение об ошибке, объясняющее проблему.

...