Вы можете использовать идеи из Несколько агрегаций , это может сделать все за одну операцию случайного воспроизведения, которая является самой дорогой операцией.
Пример:
val df = (Seq((1, "a", "1"),
(1,"b", "3"),
(1,"c", "6"),
(2, "a", "9"),
(2,"c", "10"),
(1,"b","8" ),
(2, "c", "3"),
(3,"r", "19")).toDF("col1", "col2", "col3"))
df.createOrReplaceTempView("data")
val grpRes = spark.sql("""select grouping_id() as gid, col1, col2, round(mean(col3), 2) as res
from data group by col1, col2 grouping sets ((col1), (col2)) """)
grpRes.show(100, false)
Выход:
+---+----+----+----+
|gid|col1|col2|res |
+---+----+----+----+
|1 |3 |null|19.0|
|2 |null|b |5.5 |
|2 |null|c |6.33|
|1 |1 |null|4.5 |
|2 |null|a |5.0 |
|1 |2 |null|7.33|
|2 |null|r |19.0|
+---+----+----+----+
gid немного забавно использовать, поскольку под ним есть несколько двоичных вычислений. Но если в столбцах вашей группировки не может быть пустых значений, вы можете использовать ее для выбора правильных групп.
План выполнения:
scala> grpRes.explain
== Physical Plan ==
*(2) HashAggregate(keys=[col1#111, col2#112, spark_grouping_id#108], functions=[avg(cast(col3#9 as double))])
+- Exchange hashpartitioning(col1#111, col2#112, spark_grouping_id#108, 200)
+- *(1) HashAggregate(keys=[col1#111, col2#112, spark_grouping_id#108], functions=[partial_avg(cast(col3#9 as double))])
+- *(1) Expand [List(col3#9, col1#109, null, 1), List(col3#9, null, col2#110, 2)], [col3#9, col1#111, col2#112, spark_grouping_id#108]
+- LocalTableScan [col3#9, col1#109, col2#110]
Как вы можете видеть, существует одна операция Exchange, дорогостоящее перемешивание.