У меня есть Dataframe вида:
+---+---+----+
| A| B|dist|
+---+---+----+
| a1| b1| 1.0|
| a1| b2| 2.0|
| a2| b1|10.0|
| a2| b2|10.0|
| a2| b3| 2.0|
| a3| b1|10.0|
+---+---+----+
и, фиксированный max_rank = 2, я хочу получить следующий
+---+---+----+----+
| A| B|dist|rank|
+---+---+----+----+
| a3| b1|10.0| 1|
| a2| b3| 2.0| 1|
| a2| b1|10.0| 2|
| a2| b2|10.0| 2|
| a1| b1| 1.0| 1|
| a1| b2| 2.0| 2|
+---+---+----+----+
Классический метод, который делает этоследующее
df = sqlContext.createDataFrame([("a1", "b1", 1.), ("a1", "b2", 2.), ("a2", "b1", 10.), ("a2", "b2", 10.), ("a2", "b3", 2.), ("a3", "b1", 10.)], schema=StructType([StructField("A", StringType(), True), StructField("B", StringType(), True),StructField("dist", FloatType(), True)]))
win = Window().partitionBy(df['A']).orderBy(df['dist'])
out = df.withColumn('rank', rank().over(win))
out = out.filter('rank<=2')
Однако это решение неэффективно из-за функции Window, которая использует OrderBy.
Есть ли другое решение для Pyspark? Например, метод, аналогичный .top (k, key = -) для СДР?
Я нашел аналогичный ответ здесь , но вместо python используется scala.