Создание пользовательской функции в Spark для обработки столбца вложенной структуры - PullRequest
0 голосов
/ 20 апреля 2020

В моем фрейме данных у меня сложная структура данных, которую мне нужно обработать, чтобы обновить другой столбец. Подход, который я пробую, заключается в использовании UDF. Однако, если есть более простой способ сделать это, не стесняйтесь ответить на этот вопрос.

Рассматриваемая структура фрейма данных

root
 |-- user: string (nullable = true)
 |-- cat: string (nullable = true)
 |-- data_to_update: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- _1: array (nullable = true)
 |    |    |    |-- element: double (containsNull = false)
 |    |    |-- _2: double (nullable = false)

Проблема, которую я пытаюсь решить, Обновление столбца идентификатора при мерцании. Мерцание происходит, когда столбец идентификатора изменяется много раз от идентификатора к другому; например, это будет мерцание 4, 5, 4, 5, 4.

Объединение data и identifiedData, которое показано в том, как построить секцию данных, приводит к

+----+-----+----+--------+----+---------------------------------------------------------+
|user|cat  |id  |time_sec|rn  |data_to_update                                           |
+----+-----+----+--------+----+---------------------------------------------------------+
|a   |silom|3.0 |1       |1.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|2.0 |2       |2.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|1.0 |3       |3.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|0.0 |4       |4.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|1.0 |5       |5.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|0.0 |6       |6.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|1.0 |7       |7.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|2.0 |8       |8.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|3.0 |9       |9.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|4.0 |10      |10.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|3.0 |11      |11.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|4.0 |12      |12.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|3.0 |13      |13.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|4.0 |14      |14.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|5.0 |15      |15.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|

Что data_to_update говорит нам, что rn [4.0, 5.0, 6.0] необходимо обновить, изменив id на 0.0, а rn [10.0, 11.0, 12.0, 12.0] нужно обновить, изменив id на 4.0.


Моя попытка

Я думаю, что могу использовать UDF для обработки столбца data_to_update, используя withColumn для обновления столбца id. Тем не менее, я сложная структура данных является то, что вызывает проблемы для меня. Пока у меня есть

// I will call this in my UDF
def checkArray(clusterID: Double, colRN: Double, arr: Array[(Array[Double], Double)]): Double = {

    // @tailrec
    def getReturnID(clusterID: Double, colRN: Double, arr: Array[(Array[Double], Double)]): Double = arr match {

        case arr if arr.nonEmpty && arr(0)._1.contains(colRN) =>
            arr(0)._2
        case arr if arr.nonEmpty && !arr(0)._1.contains(colRN) =>
            getReturnID(clusterID, colRN, arr.drop(1))
        case _ => clusterID
    }

    getReturnID(clusterID, colRN, arr)
}

val columnUpdate: UserDefinedFunction = udf {
    (colID: Double, colRN: Double, colArrayData: Array[(Array[Double], Double)]) =>

    if(colArrayData.length > 0) {
        checkArray(colID, colRN, colArrayData)
    }
    else {
        colID
    }
}

data
    .join(broadcast(identifiedData), Seq("user", "cat"), "inner")
    .withColumn("id", columnUpdate($"id", $"rn", $"data_to_update"))
    .show(100, false)

Я не могу получить доступ к кортежу, который был преобразован в структуру.

org.apache.spark.SparkException: Failed to execute user defined function($anonfun$1: (double, double, array<struct<_1:array<double>,_2:double>>) => double)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage7.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    at java.lang.Thread.run(Thread.java:745)
Caused by: java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to [Lscala.Tuple2;
    at $line329.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:37)

Из чтения Spark UDF для StructType / Row , я имитировал их настройку, но все еще безуспешно. Изменения, которые я сделал, были

def checkArray(clusterID: Double, 
               colRN: Double, 
               dataStruct: Row): Double = {

    // map back array to fill id value
    val arrData = dataStruct
        .getAs[Seq[Double]](0)
        .zipWithIndex.map{
            case (arr, idx) => (Array(arr), dataStruct.getAs[Seq[Double]](1)(idx))
        }

    // @tailrec
    def getReturnID(clusterID: Double, 
                    colRN: Double, 
                    arr: Seq[(Array[Double], Double)]): Double = arr match {

        case arr if arr.nonEmpty && arr(0)._1.contains(colRN) =>
            arr(0)._2
        case arr if arr.nonEmpty && !arr(0)._1.contains(colRN) =>
            getReturnID(clusterID, colRN, arr.drop(1))
        case _ => clusterID
    }

    getReturnID(clusterID, colRN, arrData)
}

val columnUpdate: UserDefinedFunction = udf {
    (colID: Double, colRN: Double, colStructData: Row) =>

    if(colStructData.getAs[Seq[Double]](0).nonEmpty) {
        checkArray(colID, colRN, colStructData)
    }
    else {
        colID
    }
}

Я передаю Row. Данные имеют вид

data.join(
    broadcast(identifiedData
        .withColumn("data_to_update", $"data_to_update")
        .withColumn("array_data", $"data_to_update._1")
        .withColumn("value", $"data_to_update._2")
        .withColumn("data_struct", struct("array_data", "value"))
        .drop("data_to_update", "array_data", "value")
             ),
              Seq("user", "cat"), 
              "inner"
    )

со следующей схемой

root
 |-- user: string (nullable = true)
 |-- cat: string (nullable = true)
 |-- id: double (nullable = false)
 |-- time_sec: integer (nullable = false)
 |-- rn: double (nullable = true)
 |-- data_struct: struct (nullable = false)
 |    |-- array_data: array (nullable = true)
 |    |    |-- element: array (containsNull = true)
 |    |    |    |-- element: double (containsNull = false)
 |    |-- value: array (nullable = true)
 |    |    |-- element: double (containsNull = true)

data.join(
    broadcast(identifiedData
        .withColumn("data_to_update", $"data_to_update")
        .withColumn("array", $"data_to_update._1")
        .withColumn("value", $"data_to_update._2")
        .withColumn("data_struct", struct("array", "value"))
        .drop("data_to_update", "array", "value")
             ),
              Seq("user", "cat"), 
              "inner"
    )
    .withColumn("id", columnUpdate($"id", $"rn", $"data_struct"))
    .show

При таком подходе я получаю следующую ошибку:

Caused by: java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to java.lang.Double

Однако я Я передаю Row, а затем выполняю getAs для преобразования в необходимую структуру данных.


Как построить данные

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.types.DoubleType
import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.expressions.UserDefinedFunction
import scala.util.Sorting.stableSort
import org.apache.spark.sql.Row

val dataDF = Seq(
    ("a", "silom", 3, 1),
    ("a", "silom", 2, 2),
    ("a", "silom", 1, 3),
    ("a", "silom", 0, 4),
    ("a", "silom", 1, 5),
    ("a", "silom", 0, 6),
    ("a", "silom", 1, 7),
    ("a", "silom", 2, 8),
    ("a", "silom", 3, 9),
    ("a", "silom", 4, 10),
    ("a", "silom", 3, 11),
    ("a", "silom", 4, 12),
    ("a", "silom", 3, 13),
    ("a", "silom", 4, 14),
    ("a", "silom", 5, 15),
    ("a", "suk", 18, 1),
    ("a", "suk", 19, 2),
    ("a", "suk", 20, 3),
    ("a", "suk", 21, 4),
    ("a", "suk", 20, 5),
    ("a", "suk", 21, 6),
    ("a", "suk", 0, 7),
    ("a", "suk", 1, 8),
    ("a", "suk", 2, 9),
    ("a", "suk", 3, 10),
    ("a", "suk", 4, 11),
    ("a", "suk", 3, 12),
    ("a", "suk", 4, 13),
    ("a", "suk", 3, 14),
    ("a", "suk", 5, 15),
    ("b", "silom", 4, 1),
    ("b", "silom", 3, 2),
    ("b", "silom", 2, 3),
    ("b", "silom", 1, 4),
    ("b", "silom", 0, 5),
    ("b", "silom", 1, 6),
    ("b", "silom", 2, 7),
    ("b", "silom", 3, 8),
    ("b", "silom", 4, 9),
    ("b", "silom", 5, 10),
    ("b", "silom", 6, 11),
    ("b", "silom", 7, 12),
    ("b", "silom", 8, 13),
    ("b", "silom", 9, 14),
    ("b", "silom", 10, 15),
    ("b", "suk", 11, 1),
    ("b", "suk", 12, 2),
    ("b", "suk", 13, 3),
    ("b", "suk", 14, 4),
    ("b", "suk", 13, 5),
    ("b", "suk", 14, 6),
    ("b", "suk", 13, 7),
    ("b", "suk", 12, 8),
    ("b", "suk", 11, 9),
    ("b", "suk", 10, 10),
    ("b", "suk", 9, 11),
    ("b", "suk", 8, 12),
    ("b", "suk", 7, 13),
    ("b", "suk", 6, 14),
    ("b", "suk", 5, 15)
).toDF("user", "cat", "id", "time_sec")
val recastDataDF = dataDF.withColumn("id", $"id".cast(DoubleType))

val category = recastDataDF.select("cat").distinct.collect.map(x => x(0).toString)

val data = recastDataDF
    .withColumn("rn", row_number.over(Window.partitionBy("user", "cat").orderBy("time_sec")).cast(DoubleType))

val data2 = recastDataDF
    .select($"*" +: category.map(
        name => 
        lag("id", 1).over(
            Window.partitionBy("user", "cat").orderBy("time_sec")
        )
        .alias(s"lag_${name}_id")): _*)
    .withColumn("sequencing_diff", when($"cat" === "silom", ($"lag_silom_id" - $"id").cast(DoubleType))
                .otherwise(($"lag_suk_id" - $"id")))
    .drop("lag_silom_id", "lag_suk_id")
    .withColumn("rn", row_number.over(Window.partitionBy("user", "cat").orderBy("time_sec")).cast(DoubleType))
    .withColumn("id_rn", array($"id", $"rn", $"sequencing_diff"))
    .groupBy($"user", $"cat").agg(collect_list($"id_rn").alias("array_data"))

def getGrps2(arr: Array[(Double, Double)]): Array[(Array[Double], Double)] = {

    // @tailrec
    def returnAlternatingIDs(arr: Array[(Double, Double)], 
                             altIDs: Array[(Array[Double], Double)]): Array[(Array[Double], Double)] = arr match {

        case arr if arr.nonEmpty =>
            val rowNum = arr.take(1)(0)._1
            val keepID = arr.take(1)(0)._2
            val newArr = arr.drop(1)

            val rowNums = (Array(rowNum)) ++ newArr.zipWithIndex.map{
                case (tups, idx) => 
                if(rowNum + idx + 1 == tups._1) {
                    rowNum + 1 + idx
                }
                else {
                    Double.NaN
                }
            }
                .filter(v => v == v)

            val updateArray = altIDs ++ Array((rowNums, keepID))
            returnAlternatingIDs(arr.drop(rowNums.length), updateArray)
        case _ => altIDs
    }

    returnAlternatingIDs(arr, Array((Array(Double.NaN), Double.NaN))).drop(1)
}

val identifyFlickeringIDs: UserDefinedFunction = udf {
    (colArrayData: WrappedArray[WrappedArray[Double]]) =>
    val newArray: Array[(Double, Double, Double)] = colArrayData.toArray
        .map(x => (x(0).toDouble, x(1).toDouble, x(2).toDouble))

    // sort array by rn via less than relation
    stableSort(newArray, (e1: (Double, Double, Double), e2: (Double, Double, Double)) => e1._2 < e2._2)

    val shifted: Array[(Double, Double, Double)] = newArray.toArray.drop(1)
    val combined = newArray
        .zipAll(shifted, (Double.NaN, Double.NaN, Double.NaN), (Double.NaN, Double.NaN, Double.NaN))

    val parsedArray = combined.map{
        case (data0, data1) =>
        if(data0._3 != data1._3 && data0._2 > 1 && data0._3 + data1._3 == 0) {
            (data0._2, data0._1)
        }
        else (Double.NaN, Double.NaN)
        }
        .filter(t => t._1 == t._1 && t._1 == t._1)

    getGrps2(parsedArray).filter(data => data._1.length > 1)
}

val identifiedData = data2.withColumn("data_to_update", identifyFlickeringIDs($"array_data"))
    .drop("array_data")

1 Ответ

0 голосов
/ 20 апреля 2020

Я нашел решение, деконструируя столбец, поскольку он был в формате array<struct<array<double>, double>> и следовал Spark UDF для StructType / Row . Тем не менее, я считаю, что все еще может быть более краткий способ сделать это.

def checkArray(clusterID: Double, 
               colRN: Double, 
               dataStruct: Row): Double = {
    // Array[(Array[Double], Double)]
    val arrData: Seq[(Seq[Double], Double)] = dataStruct
        .getAs[Seq[Seq[Double]]](0)
        .zipWithIndex.map{
            case (arr, idx) => (arr, dataStruct.getAs[Seq[Double]](1)(idx))
        }

    // @tailrec
    def getReturnID(clusterID: Double, 
                    colRN: Double, 
                    arr: Seq[(Seq[Double], Double)]): Double = arr match {

        case arr if arr.nonEmpty && arr(0)._1.contains(colRN) =>
            arr(0)._2
        case arr if arr.nonEmpty && !arr(0)._1.contains(colRN) =>
            getReturnID(clusterID, colRN, arr.drop(1))
        case _ => clusterID
    }

    getReturnID(clusterID, colRN, arrData)
}

val columnUpdate: UserDefinedFunction = udf {
    (colID: Double, colRN: Double, colStructData: Row) =>

    if(colStructData.getAs[Seq[Double]](0).nonEmpty) {
        checkArray(colID, colRN, colStructData)
    }
    else {
        colID
    }
}

// I believe all this withColumns are unnecessary but this was the only way
// I could get a working solution
data.join(
    broadcast(identifiedData
        .withColumn("data_to_update", $"data_to_update")
        .withColumn("array", $"data_to_update._1")
        .withColumn("value", $"data_to_update._2")
        .withColumn("data_struct", struct("array", "value"))
        .drop("data_to_update", "array", "value")
             ),
              Seq("user", "cat"), 
              "inner"
    )
    .withColumn("id", columnUpdate($"id", $"rn", $"data_struct"))
    .show(100, false)
...