Spark собирать ограниченный отсортированный список - PullRequest
1 голос
/ 26 января 2020

Я пытаюсь использовать spark для создания ограниченного отсортированного списка для фрейма данных, однако я не могу думать о быстром и низком объеме памяти.

Мой фрейм данных состоит из трех столбцов и двух идентификаторов ключей и столбец расстояния, и я хочу получить список лучших n = 50 идентификаторов, близких к каждому из идентификаторов. Я попробовал groupBy, а затем collect_list, затем sort_array, а затем UDF, чтобы получить только идентификаторы и, наконец, передать его через UDF, чтобы получить первые n = 50, но это очень медленно и иногда вызывает ошибку памяти.

# Sample Data
val dataFrameTest = Seq(
      ("key1", "key2", 1),
      ("key1","key3", 2),
      ("key1", "key5" ,4),
      ("key1", "key6" ,5),
      ("key1","key8" ,6),
      ("key2", "key7" ,3),
      ("key2", "key9" ,4),
      ("key2","key5" ,5)
      ).toDF("id1", "id2", "distance")

Если ограничение равно 2, то нужно

"key1" | ["key2", "key3"]    
"key2" | ["key7", "key8"]

current_approach:

sorted_df = dataFrameTest.groupBy("key1").agg(collect_list(struct("distance", "id2")).alias("toBeSortedCol")).
withColumn("sortedList", sort_array("toBeSortedCol"))

Мои данные достаточно велики, поэтому единственное решение - искра. Я ценю любую помощь / руководство.

1 Ответ

1 голос
/ 26 января 2020

Как насчет использования для этого одной из оконных функций Spark SQL? Что-то вроде

scala> val dataFrameTest = Seq(
     |       ("key1", "key2", 1),
     |       ("key1","key3", 2),
     |       ("key1", "key5" ,4),
     |       ("key1", "key6" ,5),
     |       ("key1","key8" ,6),
     |       ("key2", "key7" ,3),
     |       ("key2", "key9" ,4),
     |       ("key2","key5" ,5)
     |       ).toDF("id1", "id2", "distance")
dataFrameTest: org.apache.spark.sql.DataFrame = [id1: string, id2: string ... 1 more field]

scala> dataFrameTest.createOrReplaceTempView("sampledata")

scala> spark.sql("""
     | select t.id1, collect_list(t.id2) from (
     | select id1, id2, row_number() over (partition by id1 order by distance) as rownum from sampledata
     | )t
     | where t.rownum < 3 group by t.id1
     | """).show(false)
+----+-----------------+
|id1 |collect_list(id2)|
+----+-----------------+
|key1|[key2, key3]     |
|key2|[key7, key9]     |
+----+-----------------+

scala>

Просто замените row_number() на rank() или dense_rank() в зависимости от типа нужного вам результата.

...