IIUC, вы можете избежать дорогостоящего объединения и добиться этого, используя один groupBy
.
final_frame_2 = df.groupBy("category").agg(
F.sum(F.col("value")*F.col("flag")).alias("foo1"),
F.sum(F.col("value")).alias("foo2"),
)
final_frame_2.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#| B| 0.0|15.0|
#| A|10.0|22.0|
#+--------+----+----+
Теперь сравните планы выполнения:
Сначала ваш метод:
final_frame.explain()
#== Physical Plan ==
#*(5) Project [category#0, foo1#68, foo2#75]
#+- SortMergeJoin [category#0], [category#78], LeftOuter
# :- *(2) Sort [category#0 ASC NULLS FIRST], false, 0
# : +- *(2) HashAggregate(keys=[category#0], functions=[sum(cast(value#1 as double))])
# : +- Exchange hashpartitioning(category#0, 200)
# : +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum(cast(value#1 as double))])
# : +- *(1) Project [category#0, value#1]
# : +- *(1) Filter (isnotnull(flag#2) && (cast(flag#2 as int) = 1))
# : +- Scan ExistingRDD[category#0,value#1,flag#2]
# +- *(4) Sort [category#78 ASC NULLS FIRST], false, 0
# +- *(4) HashAggregate(keys=[category#78], functions=[sum(cast(value#79 as double))])
# +- Exchange hashpartitioning(category#78, 200)
# +- *(3) HashAggregate(keys=[category#78], functions=[partial_sum(cast(value#79 as double))])
# +- *(3) Project [category#78, value#79]
# +- Scan ExistingRDD[category#78,value#79,flag#80]
Теперь то же самое для final_frame_2
:
final_frame_2.explain()
#== Physical Plan ==
#*(2) HashAggregate(keys=[category#0], functions=[sum((cast(value#1 as double) * cast(flag#2 as double))), sum(cast(value#1 as double))])
#+- Exchange hashpartitioning(category#0, 200)
# +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum((cast(value#1 as double) * cast(flag#2 as double))), partial_sum(cast(value#1 as double))])
# +- Scan ExistingRDD[category#0,value#1,flag#2]
Примечание : Строго говоря, это не точный такой же вывод, как в приведенном вами примере(показано ниже), поскольку ваше внутреннее объединение удалит все категории, в которых нет строки с flag = 1
.
final_frame.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#| A|10.0|22.0|
#+--------+----+----+
Вы можете добавить агрегацию к сумме flag
и отфильтровать те, где сумма равна нулю, если это требование, с незначительным ударом по производительности.
final_frame_3 = df.groupBy("category").agg(
F.sum(F.col("value")*F.col("flag")).alias("foo1"),
F.sum(F.col("value")).alias("foo2"),
F.sum(F.col("flag")).alias("foo3")
).where(F.col("foo3")!=0).drop("foo3")
final_frame_3.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#| A|10.0|22.0|
#+--------+----+----+