Spark - группируйте и объединяйте только несколько самых маленьких предметов - PullRequest
1 голос
/ 27 июня 2019

Короче

У меня есть декартово произведение (кросс-соединение) двух фреймов данных и функция, которая дает некоторую оценку для данного элемента этого продукта. Теперь я хочу получить несколько «наиболее подходящих» элементов второго DF для каждого члена первого DF.

Подробнее

Ниже приведен упрощенный пример, поскольку мой настоящий код несколько раздут с дополнительными полями и фильтрами.

Учитывая два набора данных, каждый из которых имеет некоторый идентификатор и значение:

// simple rdds of tuples
val rdd1 = sc.parallelize(Seq(("a", 31),("b", 41),("c", 59),("d", 26),("e",53),("f",58)))
val rdd2 = sc.parallelize(Seq(("z", 16),("y", 18),("x",3),("w",39),("v",98), ("u", 88)))

// convert them to dataframes:
val df1 = spark.createDataFrame(rdd1).toDF("id1", "val1")
val df2 = spark.createDataFrame(rdd2).toDF("id2", "val2")

и некоторая функция, которая для пары элементов из первого и второго набора данных дает их «соответствующий счет»:

def f(a:Int, b:Int):Int = (a * a + b * b * b) % 17
// convert it to udf
val fu = udf((a:Int, b:Int) => f(a, b))

мы можем создать произведение двух наборов и рассчитать оценку для каждой пары:

val dfc = df1.crossJoin(df2)
val r = dfc.withColumn("rez", fu(col("val1"), col("val2")))
r.show

+---+----+---+----+---+
|id1|val1|id2|val2|rez|
+---+----+---+----+---+
|  a|  31|  z|  16|  8|
|  a|  31|  y|  18| 10|
|  a|  31|  x|   3|  2|
|  a|  31|  w|  39| 15|
|  a|  31|  v|  98| 13|
|  a|  31|  u|  88|  2|
|  b|  41|  z|  16| 14|
|  c|  59|  z|  16| 12|
...

И теперь мы хотим сгруппировать этот результат по id1:

r.groupBy("id1").agg(collect_set(struct("id2", "rez")).as("matches")).show

+---+--------------------+
|id1|             matches|
+---+--------------------+
|  f|[[v,2], [u,8], [y...|
|  e|[[y,5], [z,3], [x...|
|  d|[[w,2], [x,6], [v...|
|  c|[[w,2], [x,6], [v...|
|  b|[[v,2], [u,8], [y...|
|  a|[[x,2], [y,10], [...|
+---+--------------------+

Но на самом деле мы хотим сохранить только несколько (скажем, 3) «матчей», имеющих лучший результат (скажем, наименьший).

Вопрос

  1. Как отсортировать "спички" и сократить их до топ-N элементов? Вероятно, это что-то вроде collect_list и sort_array, хотя я не знаю, как сортировать по внутреннему полю.

  2. Есть ли способ обеспечить оптимизацию в случае больших входных ДФ - например, выбирая минимумы напрямую при агрегировании. Я знаю, что это было бы легко сделать, если бы я писал код без искры - сохраняя небольшой массив или очередь с приоритетами для каждого id1 и добавляя элемент там, где он должен быть, возможно, исключая некоторые ранее добавленные.

например. Это нормально, что перекрестное объединение является дорогостоящей операцией, но я хочу не тратить память на результаты, большинство из которых я собираюсь опустить на следующем шаге. Мой реальный сценарий использования касается DF с менее чем 1 млн. Записей, поэтому перекрестное объединение все же жизнеспособно, но, поскольку мы хотим выбрать только 10-20 лучших совпадений для каждого id1, представляется весьма желательным не хранить ненужные данные между этапами. .

1 Ответ

1 голос
/ 27 июня 2019

Для начала нам нужно взять только первые n строк.Для этого мы разделяем DF по 'id1' и сортируем группы по res.Мы используем его для добавления столбца с номером строки в DF, например, мы можем использовать функцию , где , чтобы взять первые n строк.Чем вы можете продолжать делать тот же код, который вы написали.Группировка по 'id1' и сбор списка.Только теперь у вас уже есть самые высокие строки.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val n = 3
val w = Window.partitionBy($"id1").orderBy($"res".desc)
val res = r.withColumn("rn", row_number.over(w)).where($"rn" <= n).groupBy("id1").agg(collect_set(struct("id2", "res")).as("matches"))

Второй вариант, который может быть лучше, потому что вам не нужно группировать DF дважды:

val sortTakeUDF = udf{(xs: Seq[Row], n: Int)} => xs.sortBy(_.getAs[Int]("res")).reverse.take(n).map{case Row(x: String, y:Int)}}
r.groupBy("id1").agg(sortTakeUDF(collect_set(struct("id2", "res")), lit(n)).as("matches"))

Здесь мы создаемudf, который принимает столбец массива и целочисленное значение n.Udf сортирует массив по вашему 'res' и возвращает только первые n элементов.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...