У меня есть набор данных, где я могу определить чередующуюся последовательность. Однако я хочу сгруппировать эти данные в один фрагмент, оставив все остальные данные без изменений. То есть там, где когда-либо происходит мерцание в id, я хочу перезаписать эту группу идентификаторов с единственным идентификатором, который делает с тех пор в порядок. В качестве небольшого примера рассмотрим
val dataDF = Seq(
("a", "silom", 3, 1),
("a", "silom", 2, 2),
("a", "silom", 1, 3),
("a", "silom", 0, 4), // flickering; id=0
("a", "silom", 1, 5), // flickering; id=0
("a", "silom", 0, 6), // flickering; id=0
("a", "silom", 1, 7),
("a", "silom", 2, 8),
("a", "silom", 3, 9),
("a", "silom", 4, 10),
("a", "silom", 3, 11), // flickering and so on
("a", "silom", 4, 12),
("a", "silom", 3, 13),
("a", "silom", 4, 14),
("a", "silom", 5, 15)
).toDF("user", "cat", "id", "time_sec")
val resultDataDF = Seq(
("a", "silom", 3, 1),
("a", "silom", 2, 2),
("a", "silom", 1, 3),
("a", "silom", 0, 15), // grouped by flickering summing on time_sec
("a", "silom", 1, 7),
("a", "silom", 2, 8),
("a", "silom", 3, 9),
("a", "silom", 4, 60),
("a", "silom", 5, 15). // grouped by flickering summing on time_sec
).toDF("user", "cat", "id", "time_sec")
Теперь более реалистично c MWE. В этом случае у нас может быть несколько пользователей и cat; К сожалению, этот подход не использует API-интерфейс dataframe и должен собирать данные для драйвера. Это не масштабируется и требует рекурсивного вызова getGrps
, отбрасывая длину возвращаемых индексов массива.
Как я могу реализовать это с помощью API dataframe, чтобы не нужно было собирать данные для драйвер, который был бы невозможен из-за размера? Кроме того, если есть лучший способ сделать это, что бы это было?
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.types.DoubleType
import scala.collection.mutable.WrappedArray
val dataDF = Seq(
("a", "silom", 3, 1),
("a", "silom", 2, 2),
("a", "silom", 1, 3),
("a", "silom", 0, 4),
("a", "silom", 1, 5),
("a", "silom", 0, 6),
("a", "silom", 1, 7),
("a", "silom", 2, 8),
("a", "silom", 3, 9),
("a", "silom", 4, 10),
("a", "silom", 3, 11),
("a", "silom", 4, 12),
("a", "silom", 3, 13),
("a", "silom", 4, 14),
("a", "silom", 5, 15),
("a", "suk", 18, 1),
("a", "suk", 19, 2),
("a", "suk", 20, 3),
("a", "suk", 21, 4),
("a", "suk", 20, 5),
("a", "suk", 21, 6),
("a", "suk", 0, 7),
("a", "suk", 1, 8),
("a", "suk", 2, 9),
("a", "suk", 3, 10),
("a", "suk", 4, 11),
("a", "suk", 3, 12),
("a", "suk", 4, 13),
("a", "suk", 3, 14),
("a", "suk", 5, 15),
("b", "silom", 4, 1),
("b", "silom", 3, 2),
("b", "silom", 2, 3),
("b", "silom", 1, 4),
("b", "silom", 0, 5),
("b", "silom", 1, 6),
("b", "silom", 2, 7),
("b", "silom", 3, 8),
("b", "silom", 4, 9),
("b", "silom", 5, 10),
("b", "silom", 6, 11),
("b", "silom", 7, 12),
("b", "silom", 8, 13),
("b", "silom", 9, 14),
("b", "silom", 10, 15),
("b", "suk", 11, 1),
("b", "suk", 12, 2),
("b", "suk", 13, 3),
("b", "suk", 14, 4),
("b", "suk", 13, 5),
("b", "suk", 14, 6),
("b", "suk", 13, 7),
("b", "suk", 12, 8),
("b", "suk", 11, 9),
("b", "suk", 10, 10),
("b", "suk", 9, 11),
("b", "suk", 8, 12),
("b", "suk", 7, 13),
("b", "suk", 6, 14),
("b", "suk", 5, 15)
).toDF("user", "cat", "id", "time_sec")
val recastDataDF = dataDF.withColumn("id", $"id".cast(DoubleType))
val category = recastDataDF.select("cat").distinct.collect.map(x => x(0).toString)
val data = recastDataDF
.select($"*" +: category.map(
name =>
lag("id", 1).over(
Window.partitionBy("user", "cat").orderBy("time_sec")
)
.alias(s"lag_${name}_id")): _*)
.withColumn("sequencing_diff", when($"cat" === "silom", ($"lag_silom_id" - $"id").cast(DoubleType))
.otherwise(($"lag_suk_id" - $"id")))
.drop("lag_silom_id", "lag_suk_id")
.withColumn("rn", row_number.over(Window.partitionBy("user", "cat").orderBy("time_sec")).cast(DoubleType))
.withColumn("zipped", array("user", "cat", "sequencing_diff", "rn", "id"))
// non dataframe API approach (not scalable)
// needs to collect data to driver to process
val iterTuples = data.select("zipped").collect.map(x => x(0).asInstanceOf[WrappedArray[Any]]).map(x => x.toArray)
val shifted: Array[Array[Any]] = iterTuples.drop(1)
val combined = iterTuples
.zipAll(shifted, Array("", "", Double.NaN, Double.NaN, Double.NaN), Array("", "", Double.NaN, Double.NaN, Double.NaN))
val testArr = combined.map{
case (data0, data1) =>
if(data1(3).toString.toDouble > 2 && data0(3).toString.toDouble > 2 && data1(0) == data0(0) && data1(1) == data0(1)) {
if(data0(2) != data1(2) && data0(2).toString.toDouble + data1(2).toString.toDouble == 0) {
(data1(0), data1(1), data1(3), data0(4))
}
else ("", "", Double.NaN, Double.NaN)
}
else ("", "", Double.NaN, Double.NaN)
}
.filter(t => t._1 != "" && t._2 != "" && t._3 == t._3 && t._4 == t._4) // fast NaN removal
val typeMappedArray = testArr.map(x => (x._1.toString, x._2.toString, x._3.toString.toDouble, x._4.toString.toDouble))
def getGrps(arr: Array[(String, String, Double, Double)]): (Array[Double], Double, String, String) = {
if(arr.nonEmpty) {
val user = arr.take(1)(0)._1
val cat = arr.take(1)(0)._2
val rowNum = arr.take(1)(0)._3
val keepID = arr.take(1)(0)._4
val newArr = arr.drop(1)
val rowNums = (Array(rowNum)) ++ newArr.zipWithIndex.map{
case (tups, idx) =>
if(rowNum + idx + 1 == tups._3) {
rowNum + 1 + idx
}
else Double.NaN
}
.filter(v => v == v)
(rowNums, keepID, user, cat)
}
else (Array(Double.NaN), Double.NaN, "", "")
}
// after overwriting, this would allow me to group by user, cat, id to sum the time
getGrps(typeMappedArray) // returns rows number to overwrite, value to overwrite id with, user, cat
res0: (Array(5.0, 6.0, 7.0),0.0,a,silom)
getGrps(typeMappedArray.drop(3))
res1: (Array(11.0, 12.0, 13.0, 14.0),4.0,a,silom)
Второй подход, использующий collect_list
, но он опирается на getGrps
рекурсивную работу, которую я не могу работать должным образом. Вот код, который у меня есть с измененным getGrps
для collect_list
минус рекурсивный.
val data = recastDataDF
.select($"*" +: category.map(
name =>
lag("id", 1).over(
Window.partitionBy("user", "cat").orderBy("time_sec")
)
.alias(s"lag_${name}_id")): _*)
.withColumn("sequencing_diff", when($"cat" === "silom", ($"lag_silom_id" - $"id").cast(DoubleType))
.otherwise(($"lag_suk_id" - $"id")))
.drop("lag_silom_id", "lag_suk_id")
.withColumn("rn", row_number.over(Window.partitionBy("user", "cat").orderBy("time_sec")).cast(DoubleType))
.withColumn("id_rn", array($"id", $"rn", $"sequencing_diff"))
.groupBy($"user", $"cat").agg(collect_list($"id_rn").alias("array_data"))
// collect one row to develop how the UDF would work
val testList = data.where($"user" === "a" && $"cat" === "silom").select("array_data").collect
.map(x => x(0).asInstanceOf[WrappedArray[WrappedArray[Any]]])
.map(x => x.toArray)
.head
.map(x => (x(0).toString.toDouble, x(1).toString.toDouble, x(2).asInstanceOf[Double]))
// this code would be in the UDF; that is, we would pass array_data to the UDF
scala.util.Sorting.stableSort(testList, (e1: (Double, Double, Double), e2: (Double, Double, Double)) => e1._2 < e2._2)
val shifted: Array[(Double, Double, Double)] = testList.drop(1)
val combined = testList
.zipAll(shifted, (Double.NaN, Double.NaN, Double.NaN), (Double.NaN, Double.NaN, Double.NaN))
val testArr = combined.map{
case (data0, data1) =>
if(data0._3 != data1._3 && data0._2 > 1) {
(data0._2, data0._1)
}
else (Double.NaN, Double.NaN)
}
.filter(t => t._1 == t._1 && t._1 == t._1)
// called inside the UDF
def getGrps2(arr: Array[(Double, Double)]): (Array[Double], Double) = {
// no need for user or cat
if(arr.nonEmpty) {
val rowNum = arr.take(1)(0)._1
val keepID = arr.take(1)(0)._2
val newArr = arr.drop(1)
val rowNums = (Array(rowNum)) ++ newArr.zipWithIndex.map{
case (tups, idx) =>
if(rowNum + idx + 1 == tups._1) {
rowNum + 1 + idx
}
else Double.NaN
}
.filter(v => v == v)
(rowNums, keepID)
}
else (Array(Double.NaN), Double.NaN)
}
Мы бы .withColumn("data_to_update", udf)
и столбец data_to_update
был бы WrappedArray[Tuple2[Array[Double], Double]]
с row_numbers для идентификатора для перезаписи. Результат для пользователя a
, cat silom
будет
WrappedArray((Array(4.0, 5.0, 6.0),0.0), (Array(10.0, 11.0, 12.0, 13.0),4.0))
Части массива - это номера строк, а Double - это id для обновления этих строк с