проблема для реализации функции UDF с Window (Pyspark) - PullRequest
0 голосов
/ 02 марта 2020

Я попытаюсь реализовать функцию UDFs pyspark в моем фрейме данных pyspark. Функция должна применяться на основе ключевого столбца, называемого esn, и в результате это массив значений с плавающей запятой.

Чтобы гарантировать, что функция будет работать с каждым esn, я использую функцию Window, как мы можем смотри здесь:

Dataframe: 
+-------------------+-------------------+---+
|               data|               date|esn|
+-------------------+-------------------+---+
| -9.195395665503895|2020-02-21 00:00:00|  a|
|  -8.62902800042645|2020-02-22 00:00:00|  a|
| -8.312878310906097|2020-02-23 00:00:00|  a|
| -9.385124458386779|2020-02-24 00:00:00|  a|
|-11.020214171187199|2020-02-25 00:00:00|  a|
|-11.409155421735703|2020-02-26 00:00:00|  a|
|-10.836121514743244|2020-02-27 00:00:00|  a|
|-11.170291456965106|2020-02-28 00:00:00|  a|
| 15.045650332290494|2020-02-21 00:00:00|  b|
| 14.398770139184752|2020-02-22 00:00:00|  b|
| 14.881674565121823|2020-02-23 00:00:00|  b|
| 14.531230965547016|2020-02-24 00:00:00|  b|
|-1.5399506984502245|2020-02-25 00:00:00|  b|
|-1.8350032777094971|2020-02-26 00:00:00|  b|
|-1.8318166902926658|2020-02-27 00:00:00|  b|
|-1.3098303461350946|2020-02-28 00:00:00|  b|
+-------------------+-------------------+---+
from pyspark.sql.functions import pandas_udf, PandasUDFType, lit
from pyspark.sql import Window
from pyspark.sql.types import *
import numpy as np

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def cp_udf_test(data):

  R, maxes = oncd.online_changepoint_detection(data, partial(oncd.constant_hazard, 250), oncd.StudentT(0.1, .01, 1, 0))

  cp_probs = np.array(R[Nw,Nw:-1][1:])

  gap_output = len(data) - len(cp_probs)
  #the below code ensures that the output will be the same size of input
  result=np.add(np.zeros(gap_output),cp_probs)

  return result

#the window

w = Window.partitionBy('esn').orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

new_df = df.withColumn('probs', cp_udf_test(df['data']).over(w))


...