Я попытаюсь реализовать функцию UDFs pyspark в моем фрейме данных pyspark. Функция должна применяться на основе ключевого столбца, называемого esn, и в результате это массив значений с плавающей запятой.
Чтобы гарантировать, что функция будет работать с каждым esn, я использую функцию Window, как мы можем смотри здесь:
Dataframe:
+-------------------+-------------------+---+
| data| date|esn|
+-------------------+-------------------+---+
| -9.195395665503895|2020-02-21 00:00:00| a|
| -8.62902800042645|2020-02-22 00:00:00| a|
| -8.312878310906097|2020-02-23 00:00:00| a|
| -9.385124458386779|2020-02-24 00:00:00| a|
|-11.020214171187199|2020-02-25 00:00:00| a|
|-11.409155421735703|2020-02-26 00:00:00| a|
|-10.836121514743244|2020-02-27 00:00:00| a|
|-11.170291456965106|2020-02-28 00:00:00| a|
| 15.045650332290494|2020-02-21 00:00:00| b|
| 14.398770139184752|2020-02-22 00:00:00| b|
| 14.881674565121823|2020-02-23 00:00:00| b|
| 14.531230965547016|2020-02-24 00:00:00| b|
|-1.5399506984502245|2020-02-25 00:00:00| b|
|-1.8350032777094971|2020-02-26 00:00:00| b|
|-1.8318166902926658|2020-02-27 00:00:00| b|
|-1.3098303461350946|2020-02-28 00:00:00| b|
+-------------------+-------------------+---+
from pyspark.sql.functions import pandas_udf, PandasUDFType, lit
from pyspark.sql import Window
from pyspark.sql.types import *
import numpy as np
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def cp_udf_test(data):
R, maxes = oncd.online_changepoint_detection(data, partial(oncd.constant_hazard, 250), oncd.StudentT(0.1, .01, 1, 0))
cp_probs = np.array(R[Nw,Nw:-1][1:])
gap_output = len(data) - len(cp_probs)
#the below code ensures that the output will be the same size of input
result=np.add(np.zeros(gap_output),cp_probs)
return result
#the window
w = Window.partitionBy('esn').orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
new_df = df.withColumn('probs', cp_udf_test(df['data']).over(w))