PySpark udf возвращает ноль, когда функция работает в кадре данных Pandas - PullRequest
0 голосов
/ 18 октября 2019

Я пытаюсь создать пользовательскую функцию, которая принимает накопленную сумму массива и сравнивает значение с другим столбцом. Вот воспроизводимый пример:

from pyspark.sql.session import SparkSession

# instantiate Spark
spark = SparkSession.builder.getOrCreate()

# make some test data
columns = ['loc', 'id', 'date', 'x', 'y']
vals = [
    ('a', 'b', '2016-07-01', 1, 5),
    ('a', 'b', '2016-07-02', 0, 5),
    ('a', 'b', '2016-07-03', 5, 15),
    ('a', 'b', '2016-07-04', 7, 5),
    ('a', 'b', '2016-07-05', 8, 20),
    ('a', 'b', '2016-07-06', 1, 5)
]

# create DataFrame
temp_sdf = (spark
      .createDataFrame(vals, columns)
      .withColumn('x_ary', collect_list('x').over(Window.partitionBy(['loc','id']).orderBy(desc('date')))))

temp_df = temp_sdf.toPandas()

def test_function(x_ary, y):
  cumsum_array = np.cumsum(x_ary) 
  result = len([x for x in cumsum_array if x <= y])
  return result

test_function_udf = udf(test_function, ArrayType(LongType()))

temp_df['len'] = temp_df.apply(lambda x: test_function(x['x_ary'], x['y']), axis = 1)
display(temp_df)

В Pandas это вывод:

loc id  date        x   y   x_ary           len
a   b   2016-07-06  1   5   [1]             1
a   b   2016-07-05  8   20  [1,8]           2
a   b   2016-07-04  7   5   [1,8,7]         1
a   b   2016-07-03  5   15  [1,8,7,5]       2
a   b   2016-07-02  0   5   [1,8,7,5,0]     1
a   b   2016-07-01  1   5   [1,8,7,5,0,1]   1

В Spark с использованием temp_sdf.withColumn('len', test_function_udf('x_ary', 'y')) все len в конечном итоге составляют null.

Кто-нибудь знает, почему это так?

Кроме того, замена cumsum_array = np.cumsum(np.flip(x_ary)) завершается неудачно в pySpark с ошибкой AttributeError: module 'numpy' has no attribute 'flip', но я знаю, что она существует, поскольку я могу нормально работать с Pandasdataframe.
Можно ли решить эту проблему или есть лучший способ перевернуть массивы с помощью pySpark?

Заранее благодарим за помощь.

1 Ответ

1 голос
/ 18 октября 2019

Удалите «ArrayType from udf», тогда он будет работать, как показано ниже:

test_function_udf = udf(test_function)
temp_sdf = temp_sdf.withColumn('len', 
           test_function_udf('x_ary', 'y'))
temp_sdf.show()
...