группируйте данные в pyspark и получайте данные topn в каждой группе - PullRequest
0 голосов
/ 19 октября 2019

У меня есть данные, которые могут быть просто показаны как:

conf = SparkConf().setMaster("local[*]").setAppName("test")
sc = SparkContext(conf=conf).getOrCreate()
spark = SparkSession(sparkContext=sc).builder.getOrCreate()

rdd = sc.parallelize([(1, 10), (3, 11), (1, 8), (1, 12), (3, 7), (3, 9)])
data = spark.createDataFrame(rdd, ['x', 'y'])
data.show()

def f(x):
    y = sorted(x, reverse=True)[:2]
    return y

h_f = udf(f, IntegerType())
h_f = spark.udf.register("h_f", h_f)
data.groupBy('x').agg({"y": h_f}).show()

Но все пошло не так: AttributeError: у объекта 'function' нет атрибута '_get_object_id', как я могу получить элемент topn в каждомгруппа

1 Ответ

1 голос
/ 19 октября 2019

Учитывая, что вы ищете верхние n 'y' элементов, которые принадлежат каждой группе 'x'.

from pyspark.sql import Window
from pyspark.sql import functions as F
import sys

rdd = sc.parallelize([(1, 10), (3, 11), (1, 8), (1, 12), (3, 7), (3, 9)])
df = spark.createDataFrame(rdd, ['x', 'y'])
df.show()

df_g = df.groupBy('x').agg(F.collect_list('y').alias('y'))
df_g = df_g.withColumn('y_sorted', F.sort_array('y', asc = False))
df_g.withColumn('y_slice', F.slice(df_g.y_sorted, 1, 2)).show()

Вывод

+---+-----------+-----------+--------+
|  x|          y|   y_sorted| y_slice|
+---+-----------+-----------+--------+
|  1|[10, 8, 12]|[12, 10, 8]|[12, 10]|
|  3| [11, 7, 9]| [11, 9, 7]| [11, 9]|
+---+-----------+-----------+--------+
...