Поскольку ваш код не зависит от переменной группировки, вы должны полностью удалить groupBy
и использовать скалярный UDF вместо сгруппированной карты.
Таким образом, вам не понадобится перемешивание, и вы сможете использовать локальность данных и доступные ресурсы.
Вам нужно будет переопределить ваши функции, чтобы взять все необходимые столбцы и вернуть pandas.Series
:
def prediction_func(*cols: pandas.Series) -> pandas.Series:
... # Combine cols into a single pandas.DataFrame and apply the model
return ... # Convert result to pandas.Series and return
Пример использования:
from pyspark.sql.functions import PandasUDFType, pandas_udf, rand
import pandas as pd
import numpy as np
df = spark.range(100).select(rand(1), rand(2), rand(3)).toDF("x", "y", "z")
@pandas_udf("double", PandasUDFType.SCALAR)
def dummy_prediction_function(x, y, z):
pdf = pd.DataFrame({"x": x, "y": y, "z": z})
pdf["prediction"] = 1.0
return pdf["prediction"]
df.withColumn("prediction", dummy_prediction_function("x", "y", "z")).show(3)
+-------------------+-------------------+--------------------+----------+
| x| y| z|prediction|
+-------------------+-------------------+--------------------+----------+
|0.13385709732307427| 0.2630967864682161| 0.11641995793557336| 1.0|
| 0.5897562959687032|0.19795734254405561| 0.605595773295928| 1.0|
|0.01540012100242305|0.25419718814653214|0.006007018601722036| 1.0|
+-------------------+-------------------+--------------------+----------+
only showing top 3 rows