Это будет работать в scala. код pyspark должен быть очень похожим.
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
val df = List(
("yes", 10),
("yes", 30),
("No", 40)).toDF("private", "rate")
val df = l.toDF(List("private", "rate"))
val window =Window.partitionBy($"private")
df.
withColumn("avg",
when($"private" === "No", null).
otherwise(avg($"rate").over(window))
).
show()
Входной DF
+-------+----+
|private|rate|
+-------+----+
| yes| 10|
| yes| 30|
| No| 40|
+-------+----+
Выходной DF
+-------+----+----+
|private|rate| avg|
+-------+----+----+
| No| 40|null|
| yes| 10|20.0|
| yes| 30|20.0|
+-------+----+----+