Если вы знаете elements
, которое вам нужно посчитать, вы можете использовать его с spark2.4+.
, и это будет очень быстро. (Используя higher order function filter
и structs
)
df.show()
#+------------+
#| atr_list|
#+------------+
#|[a, b, b, c]|
#| [b, c, d]|
#+------------+
elements=['a','b','c','d']
from pyspark.sql import functions as F
collected=df.withColumn("struct", F.struct(*[(F.struct(F.expr("size(filter(atr_list,x->x={}))"\
.format("'"+y+"'"))).alias(y)) for y in elements]))\
.select(*[F.sum(F.col("struct.{}.col1".format(x))).alias(x) for x in elements])\
.collect()[0]
{elements[i]: [x for x in collected][i] for i in range(len(elements))}
Out: {'a': 1, 'b': 3, 'c': 2, 'd': 1}
2-й метод с использованием transform, aggregate, explode and groupby
(не требуется указать элементы):
from pyspark.sql import functions as F
a=df.withColumn("atr", F.expr("""transform(array_distinct(atr_list),x->aggregate(atr_list,0,(acc,y)->\
IF(y=x, acc+1,acc)))"""))\
.withColumn("zip", F.explode(F.arrays_zip(F.array_distinct("atr_list"),("atr"))))\
.select("zip.*").withColumnRenamed("0","elements")\
.groupBy("elements").agg(F.sum("atr").alias("sum"))\
.collect()
{a[i][0]: a[i][1] for i in range(len(a))}