Pyspark - групповой с фильтром - оптимизация скорости - PullRequest
1 голос
/ 06 ноября 2019

У меня есть миллиарды строк для обработки с использованием Pyspark.

Датафрейм выглядит следующим образом:

category    value    flag
   A          10       1
   A          12       0
   B          15       0
and so on...

Мне нужно выполнить две групповые операции: одну для строк, где флаг == 1, иДругое для ВСЕХ рядов. В настоящее время я делаю это:

frame_1 = df.filter(df.flag==1).groupBy('category').agg(F.sum('value').alias('foo1'))
frame_2 = df.groupBy('category').agg(F.sum('value').alias(foo2))
final_frame = frame1.join(frame2,on='category',how='left')

На данный момент этот код выполняется, но моя проблема в том, что он очень медленный. Есть ли способ улучшить этот код с точки зрения скорости, или это предел, потому что я понимаю, что отложенная оценка PySpark действительно занимает некоторое время, но является ли этот код лучшим способом сделать это?

Ответы [ 2 ]

1 голос
/ 06 ноября 2019

IIUC, вы можете избежать дорогостоящего объединения и добиться этого, используя один groupBy.

final_frame_2 = df.groupBy("category").agg(
    F.sum(F.col("value")*F.col("flag")).alias("foo1"),
    F.sum(F.col("value")).alias("foo2"),
)
final_frame_2.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#|       B| 0.0|15.0|
#|       A|10.0|22.0|
#+--------+----+----+

Теперь сравните планы выполнения:

Сначала ваш метод:

final_frame.explain()
#== Physical Plan ==
#*(5) Project [category#0, foo1#68, foo2#75]
#+- SortMergeJoin [category#0], [category#78], LeftOuter
#   :- *(2) Sort [category#0 ASC NULLS FIRST], false, 0
#   :  +- *(2) HashAggregate(keys=[category#0], functions=[sum(cast(value#1 as double))])
#   :     +- Exchange hashpartitioning(category#0, 200)
#   :        +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum(cast(value#1 as double))])
#   :           +- *(1) Project [category#0, value#1]
#   :              +- *(1) Filter (isnotnull(flag#2) && (cast(flag#2 as int) = 1))
#   :                 +- Scan ExistingRDD[category#0,value#1,flag#2]
#   +- *(4) Sort [category#78 ASC NULLS FIRST], false, 0
#      +- *(4) HashAggregate(keys=[category#78], functions=[sum(cast(value#79 as double))])
#         +- Exchange hashpartitioning(category#78, 200)
#            +- *(3) HashAggregate(keys=[category#78], functions=[partial_sum(cast(value#79 as double))])
#               +- *(3) Project [category#78, value#79]
#                  +- Scan ExistingRDD[category#78,value#79,flag#80]

Теперь то же самое для final_frame_2:

final_frame_2.explain()
#== Physical Plan ==
#*(2) HashAggregate(keys=[category#0], functions=[sum((cast(value#1 as double) * cast(flag#2 as double))), sum(cast(value#1 as double))])
#+- Exchange hashpartitioning(category#0, 200)
#   +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum((cast(value#1 as double) * cast(flag#2 as double))), partial_sum(cast(value#1 as double))])
#      +- Scan ExistingRDD[category#0,value#1,flag#2]

Примечание : Строго говоря, это не точный такой же вывод, как в приведенном вами примере(показано ниже), поскольку ваше внутреннее объединение удалит все категории, в которых нет строки с flag = 1.

final_frame.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#|       A|10.0|22.0|
#+--------+----+----+

Вы можете добавить агрегацию к сумме flag и отфильтровать те, где сумма равна нулю, если это требование, с незначительным ударом по производительности.

final_frame_3 = df.groupBy("category").agg(
    F.sum(F.col("value")*F.col("flag")).alias("foo1"),
    F.sum(F.col("value")).alias("foo2"),
    F.sum(F.col("flag")).alias("foo3")
).where(F.col("foo3")!=0).drop("foo3")

final_frame_3.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#|       A|10.0|22.0|
#+--------+----+----+
0 голосов
/ 06 ноября 2019

Обратите внимание, что операция объединения стоит дорого. Вы можете просто сделать это и добавить флаг в свои группы:

frame_1 = df.groupBy(["category", "flag"]).agg(F.sum('value').alias('foo1'))

, если у вас более двух флагов, и вы хотите сделать flag == 1 vs the rest затем:

import pyspark.sql.functions as F
frame_1 = df.withColumn("flag2", F.when(F.col("flag") == 1, 1).otherwise(0))
frame_1 = df.groupBy(["category", "flag2"]).agg(F.sum('value').alias('foo1'))

, если хотитечтобы применить групповую обработку для всех строк, просто создайте новый фрейм, в котором вы еще раз свернете для категории:

frame_1 = df.groupBy("category").agg(F.sum('foo1').alias('foo2'))

, невозможно выполнить оба действия за один шаг, потому чтопо существу, существует групповое перекрытие.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...