Я думаю, что это не проблема Spark, а логическая.
Вы должны рассмотреть возможность использования нескольких автономных функций:
- чтобы связать два интервала (назовем это
bindEntries()
)
- добавить новый интервал в интервальный аккумулятор интервалов (пусть будет
insertEntry()
)
Предположим, у нас есть фиктивные данные mockData
:
+---+-----+---+
| id|start|end|
+---+-----+---+
| 1| 22| 33|
| 1| 15| 20|
| 1| 10| 15|
| 2| 13| 16|
| 2| 10| 13|
+---+-----+---+
С помощью этих функций мое решение вашей проблемы будет таким:
val processed = mockData
.groupByKey(_.id)
.flatMapGroups { (id: Int, it: Iterator[Entry]) =>
processEntries(it)
}
Единственная цель processEntries()
- сложить все записи для каждого идентификатора в набор непересекающихся интервалов.
Вот его подпись:
def processEntries(it: Iterator[Entry]): List[Entry] =
it.foldLeft(Nil: List[Entry])(insertEntry)
Эта функция используется для получения элементов из ваших сгруппированных записей по одному и вставки их в аккумулятор, также по одному.
Функция insertEntry()
Обработка такого рода вставки:
def insertEntry(acc: List[Entry], e: Entry): List[Entry] = acc match {
case Nil => e :: Nil
case a :: as =>
val combined = bindEntries(a, e)
combined match {
case x :: y :: Nil => x :: insertEntry(as, y)
case x :: Nil => insertEntry(as, x)
case _ => a :: as
}
}
Функция bindEntries()
должна обрабатывать порядок записей для вас:
def bindEntries(x: Entry, y: Entry): List[Entry] =
(x.start > y.end, x.end < y.start) match {
case (true, _) => y :: x :: Nil
case (_, true) => x :: y :: Nil
case _ => x.copy(start = x.start min y.start, end = x.end max y.end) :: Nil
}
bindEntries()
вернет список из одной или двух записей , правильно отсортированных .
Вот идея, стоящая за этим:
insertEntry()
отсортирует все записи для вас во время вставки.
В конце концов, результирующий набор данных выглядит так:
+---+-----+---+
| id|start|end|
+---+-----+---+
| 1| 10| 20|
| 1| 22| 33|
| 2| 10| 16|
+---+-----+---+
Примечание: Функция insertEntry()
не является хвостовой рекурсивной.
Хорошая отправная точка для дальнейшей оптимизации.
И есть полное решение:
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
object AdHoc {
Logger.getLogger("org").setLevel(Level.OFF)
def main(args: Array[String]): Unit = {
import spark.implicits._
val processed = mockData
.groupByKey(_.id)
.flatMapGroups { (id, it) =>
processEntries(it)
}
mockData.show()
processed.show()
}
def processEntries(it: Iterator[Entry]): List[Entry] =
it.foldLeft(Nil: List[Entry])(insertEntry)
def insertEntry(acc: List[Entry], e: Entry): List[Entry] = acc match {
case Nil => e :: Nil
case a :: as =>
val combined = bindEntries(a, e)
combined match {
case x :: y :: Nil => x :: insertEntry(as, y)
case x :: Nil => insertEntry(as, x)
case _ => a :: as
}
}
def bindEntries(x: Entry, y: Entry): List[Entry] =
(x.start > y.end, x.end < y.start) match {
case (true, _) => y :: x :: Nil
case (_, true) => x :: y :: Nil
case _ => x.copy(start = x.start min y.start, end = x.end max y.end) :: Nil
}
lazy val mockData: Dataset[Entry] = spark.createDataset(Seq(
Entry(1, 22, 33),
Entry(1, 15, 20),
Entry(1, 10, 15),
Entry(2, 13, 16),
Entry(2, 10, 13)
))
case class Entry(id: Int, start: Int, end: Int)
implicit lazy val entryEncoder: Encoder[Entry] = Encoders.product[Entry]
lazy val spark: SparkSession = SparkSession.builder()
.master("local")
.getOrCreate()
}