Группировка по чередующейся последовательности в Spark - PullRequest
2 голосов
/ 16 апреля 2020

У меня есть набор данных, где я могу определить чередующуюся последовательность. Однако я хочу сгруппировать эти данные в один фрагмент, оставив все остальные данные без изменений. То есть там, где когда-либо происходит мерцание в id, я хочу перезаписать эту группу идентификаторов с единственным идентификатором, который делает с тех пор в порядок. В качестве небольшого примера рассмотрим

val dataDF = Seq(
    ("a", "silom", 3, 1),
    ("a", "silom", 2, 2),
    ("a", "silom", 1, 3),
    ("a", "silom", 0, 4),  // flickering; id=0
    ("a", "silom", 1, 5),  // flickering; id=0
    ("a", "silom", 0, 6),  // flickering; id=0
    ("a", "silom", 1, 7),
    ("a", "silom", 2, 8),
    ("a", "silom", 3, 9),
    ("a", "silom", 4, 10),
    ("a", "silom", 3, 11),  // flickering and so on
    ("a", "silom", 4, 12),
    ("a", "silom", 3, 13),
    ("a", "silom", 4, 14),
    ("a", "silom", 5, 15)
).toDF("user", "cat", "id", "time_sec")

val resultDataDF = Seq(
    ("a", "silom", 3, 1),
    ("a", "silom", 2, 2),
    ("a", "silom", 1, 3),
    ("a", "silom", 0, 15),  // grouped by flickering summing on time_sec
    ("a", "silom", 1, 7),
    ("a", "silom", 2, 8),
    ("a", "silom", 3, 9),
    ("a", "silom", 4, 60),
    ("a", "silom", 5, 15). // grouped by flickering summing on time_sec
).toDF("user", "cat", "id", "time_sec")

Теперь более реалистично c MWE. В этом случае у нас может быть несколько пользователей и cat; К сожалению, этот подход не использует API-интерфейс dataframe и должен собирать данные для драйвера. Это не масштабируется и требует рекурсивного вызова getGrps, отбрасывая длину возвращаемых индексов массива.

Как я могу реализовать это с помощью API dataframe, чтобы не нужно было собирать данные для драйвер, который был бы невозможен из-за размера? Кроме того, если есть лучший способ сделать это, что бы это было?

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.types.DoubleType
import scala.collection.mutable.WrappedArray

val dataDF = Seq(
    ("a", "silom", 3, 1),
    ("a", "silom", 2, 2),
    ("a", "silom", 1, 3),
    ("a", "silom", 0, 4),
    ("a", "silom", 1, 5),
    ("a", "silom", 0, 6),
    ("a", "silom", 1, 7),
    ("a", "silom", 2, 8),
    ("a", "silom", 3, 9),
    ("a", "silom", 4, 10),
    ("a", "silom", 3, 11),
    ("a", "silom", 4, 12),
    ("a", "silom", 3, 13),
    ("a", "silom", 4, 14),
    ("a", "silom", 5, 15),
    ("a", "suk", 18, 1),
    ("a", "suk", 19, 2),
    ("a", "suk", 20, 3),
    ("a", "suk", 21, 4),
    ("a", "suk", 20, 5),
    ("a", "suk", 21, 6),
    ("a", "suk", 0, 7),
    ("a", "suk", 1, 8),
    ("a", "suk", 2, 9),
    ("a", "suk", 3, 10),
    ("a", "suk", 4, 11),
    ("a", "suk", 3, 12),
    ("a", "suk", 4, 13),
    ("a", "suk", 3, 14),
    ("a", "suk", 5, 15),
    ("b", "silom", 4, 1),
    ("b", "silom", 3, 2),
    ("b", "silom", 2, 3),
    ("b", "silom", 1, 4),
    ("b", "silom", 0, 5),
    ("b", "silom", 1, 6),
    ("b", "silom", 2, 7),
    ("b", "silom", 3, 8),
    ("b", "silom", 4, 9),
    ("b", "silom", 5, 10),
    ("b", "silom", 6, 11),
    ("b", "silom", 7, 12),
    ("b", "silom", 8, 13),
    ("b", "silom", 9, 14),
    ("b", "silom", 10, 15),
    ("b", "suk", 11, 1),
    ("b", "suk", 12, 2),
    ("b", "suk", 13, 3),
    ("b", "suk", 14, 4),
    ("b", "suk", 13, 5),
    ("b", "suk", 14, 6),
    ("b", "suk", 13, 7),
    ("b", "suk", 12, 8),
    ("b", "suk", 11, 9),
    ("b", "suk", 10, 10),
    ("b", "suk", 9, 11),
    ("b", "suk", 8, 12),
    ("b", "suk", 7, 13),
    ("b", "suk", 6, 14),
    ("b", "suk", 5, 15)
).toDF("user", "cat", "id", "time_sec")
val recastDataDF = dataDF.withColumn("id", $"id".cast(DoubleType))

val category = recastDataDF.select("cat").distinct.collect.map(x => x(0).toString)

val data = recastDataDF
    .select($"*" +: category.map(
        name => 
        lag("id", 1).over(
            Window.partitionBy("user", "cat").orderBy("time_sec")
        )
        .alias(s"lag_${name}_id")): _*)
    .withColumn("sequencing_diff", when($"cat" === "silom", ($"lag_silom_id" - $"id").cast(DoubleType))
                .otherwise(($"lag_suk_id" - $"id")))
    .drop("lag_silom_id", "lag_suk_id")
    .withColumn("rn", row_number.over(Window.partitionBy("user", "cat").orderBy("time_sec")).cast(DoubleType))
    .withColumn("zipped", array("user", "cat", "sequencing_diff", "rn", "id"))

// non dataframe API approach (not scalable)
// needs to collect data to driver to process
val iterTuples = data.select("zipped").collect.map(x => x(0).asInstanceOf[WrappedArray[Any]]).map(x => x.toArray)

val shifted: Array[Array[Any]] = iterTuples.drop(1)
val combined = iterTuples
    .zipAll(shifted, Array("", "", Double.NaN, Double.NaN, Double.NaN), Array("", "", Double.NaN, Double.NaN, Double.NaN))

val testArr = combined.map{
    case (data0, data1) =>
    if(data1(3).toString.toDouble > 2 && data0(3).toString.toDouble > 2 && data1(0) == data0(0) && data1(1) == data0(1)) {
        if(data0(2) != data1(2) && data0(2).toString.toDouble + data1(2).toString.toDouble == 0) {
            (data1(0), data1(1), data1(3), data0(4))
        }
        else ("", "", Double.NaN, Double.NaN)
    }
    else ("", "", Double.NaN, Double.NaN)
}
    .filter(t => t._1 != "" && t._2 != "" && t._3 == t._3 && t._4 == t._4)  // fast NaN removal

val typeMappedArray = testArr.map(x => (x._1.toString, x._2.toString, x._3.toString.toDouble, x._4.toString.toDouble))

def getGrps(arr: Array[(String, String, Double, Double)]): (Array[Double], Double, String, String) = {

    if(arr.nonEmpty) {
        val user = arr.take(1)(0)._1
        val cat = arr.take(1)(0)._2
        val rowNum = arr.take(1)(0)._3
        val keepID = arr.take(1)(0)._4
        val newArr = arr.drop(1)

        val rowNums = (Array(rowNum)) ++ newArr.zipWithIndex.map{
            case (tups, idx) => 
            if(rowNum + idx + 1 == tups._3) {
                rowNum + 1 + idx
            }
            else Double.NaN
        }
            .filter(v => v == v)

        (rowNums, keepID, user, cat)
    }
    else (Array(Double.NaN), Double.NaN, "", "")
}


// after overwriting, this would allow me to group by user, cat, id to sum the time
getGrps(typeMappedArray)  // returns rows number to overwrite, value to overwrite id with, user, cat
res0: (Array(5.0, 6.0, 7.0),0.0,a,silom)

getGrps(typeMappedArray.drop(3))
res1: (Array(11.0, 12.0, 13.0, 14.0),4.0,a,silom)

Второй подход, использующий collect_list, но он опирается на getGrps рекурсивную работу, которую я не могу работать должным образом. Вот код, который у меня есть с измененным getGrps для collect_list минус рекурсивный.

val data = recastDataDF
    .select($"*" +: category.map(
        name => 
        lag("id", 1).over(
            Window.partitionBy("user", "cat").orderBy("time_sec")
        )
        .alias(s"lag_${name}_id")): _*)
    .withColumn("sequencing_diff", when($"cat" === "silom", ($"lag_silom_id" - $"id").cast(DoubleType))
                .otherwise(($"lag_suk_id" - $"id")))
    .drop("lag_silom_id", "lag_suk_id")
    .withColumn("rn", row_number.over(Window.partitionBy("user", "cat").orderBy("time_sec")).cast(DoubleType))
    .withColumn("id_rn", array($"id", $"rn", $"sequencing_diff"))
    .groupBy($"user", $"cat").agg(collect_list($"id_rn").alias("array_data"))

// collect one row to develop how the UDF would work
val testList = data.where($"user" === "a" && $"cat" === "silom").select("array_data").collect
    .map(x => x(0).asInstanceOf[WrappedArray[WrappedArray[Any]]])
    .map(x => x.toArray)
    .head
    .map(x => (x(0).toString.toDouble, x(1).toString.toDouble, x(2).asInstanceOf[Double]))

// this code would be in the UDF; that is, we would pass array_data to the UDF
scala.util.Sorting.stableSort(testList, (e1: (Double, Double, Double), e2: (Double, Double, Double)) => e1._2 < e2._2)

val shifted: Array[(Double, Double, Double)] = testList.drop(1)
val combined = testList
    .zipAll(shifted, (Double.NaN, Double.NaN, Double.NaN), (Double.NaN, Double.NaN, Double.NaN))

val testArr = combined.map{
    case (data0, data1) =>
    if(data0._3 != data1._3 && data0._2 > 1) {
        (data0._2, data0._1)
    }
    else (Double.NaN, Double.NaN)
    }
    .filter(t => t._1 == t._1 && t._1 == t._1) 

// called inside the UDF
def getGrps2(arr: Array[(Double, Double)]): (Array[Double], Double) = {
    // no need for user or cat

    if(arr.nonEmpty) {
        val rowNum = arr.take(1)(0)._1
        val keepID = arr.take(1)(0)._2
        val newArr = arr.drop(1)

        val rowNums = (Array(rowNum)) ++ newArr.zipWithIndex.map{
            case (tups, idx) => 
            if(rowNum + idx + 1 == tups._1) {
                rowNum + 1 + idx
            }
            else Double.NaN
        }
            .filter(v => v == v)

        (rowNums, keepID)
    }
    else (Array(Double.NaN), Double.NaN)
}

Мы бы .withColumn("data_to_update", udf) и столбец data_to_update был бы WrappedArray[Tuple2[Array[Double], Double]] с row_numbers для идентификатора для перезаписи. Результат для пользователя a, cat silom будет

WrappedArray((Array(4.0, 5.0, 6.0),0.0), (Array(10.0, 11.0, 12.0, 13.0),4.0))

Части массива - это номера строк, а Double - это id для обновления этих строк с

1 Ответ

0 голосов
/ 17 апреля 2020

Следующий рекурсивный метод, примененный в UDF, работающем со столбцом array_data, создаст желаемые результаты

def getGrps(arr: Array[(Double, Double)]): Array[(Array[Double], Double)] = {

    def returnAlternatingIDs(arr: Array[(Double, Double)], 
                             altIDs: Array[(Array[Double], Double)]): Array[(Array[Double], Double)] = arr match {

        case arr if arr.nonEmpty =>
            val rowNum = arr.take(1)(0)._1
            val keepID = arr.take(1)(0)._2
            val newArr = arr.drop(1)

            val rowNums = (Array(rowNum)) ++ newArr.zipWithIndex.map{
                case (tups, idx) => 
                if(rowNum + idx + 1 == tups._1) {
                    rowNum + 1 + idx
                }
                else {
                    Double.NaN
                }
            }
                .filter(v => v == v)

            val updateArray = altIDs ++ Array((rowNums, keepID))
            returnAlternatingIDs(arr.drop(rowNums.length), updateArray)
        case _ => altIDs
    }

    returnAlternatingIDs(arr, Array((Array(Double.NaN), Double.NaN))).drop(1)
}

Возвращаемое значение для первого collect_list равно Array((Array(5.0, 6.0, 7.0),0.0), (Array(11.0, 12.0, 13.0, 14.0),4.0)) по желанию.

Полный UDF

val identifyFlickeringIDs: UserDefinedFunction = udf {
    (colArrayData: WrappedArray[WrappedArray[Double]]) =>
    val newArray: Array[(Double, Double, Double)] = colArrayData.toArray
        .map(x => (x(0).toDouble, x(1).toDouble, x(2).toDouble))

    // sort array by rn via less than relation
    stableSort(newArray, (e1: (Double, Double, Double), e2: (Double, Double, Double)) => e1._2 < e2._2)

    val shifted: Array[(Double, Double, Double)] = newArray.toArray.drop(1)
    val combined = newArray
        .zipAll(shifted, (Double.NaN, Double.NaN, Double.NaN), (Double.NaN, Double.NaN, Double.NaN))

    val parsedArray = combined.map{
        case (data0, data1) =>
        if(data0._3 != data1._3 && data0._2 > 1 && data0._3 + data1._3 == 0) {
            (data0._2, data0._1)
        }
        else (Double.NaN, Double.NaN)
        }
        .filter(t => t._1 == t._1 && t._1 == t._1)

    getGrps(parsedArray).filter(data => data._1.length > 1)
}
...