В моем фрейме данных у меня сложная структура данных, которую мне нужно обработать, чтобы обновить другой столбец. Подход, который я пробую, заключается в использовании UDF. Однако, если есть более простой способ сделать это, не стесняйтесь ответить на этот вопрос.
Рассматриваемая структура фрейма данных
root
|-- user: string (nullable = true)
|-- cat: string (nullable = true)
|-- data_to_update: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- _1: array (nullable = true)
| | | |-- element: double (containsNull = false)
| | |-- _2: double (nullable = false)
Проблема, которую я пытаюсь решить, Обновление столбца идентификатора при мерцании. Мерцание происходит, когда столбец идентификатора изменяется много раз от идентификатора к другому; например, это будет мерцание 4, 5, 4, 5, 4
.
Объединение data
и identifiedData
, которое показано в том, как построить секцию данных, приводит к
+----+-----+----+--------+----+---------------------------------------------------------+
|user|cat |id |time_sec|rn |data_to_update |
+----+-----+----+--------+----+---------------------------------------------------------+
|a |silom|3.0 |1 |1.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|2.0 |2 |2.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|1.0 |3 |3.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|0.0 |4 |4.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|1.0 |5 |5.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|0.0 |6 |6.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|1.0 |7 |7.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|2.0 |8 |8.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|3.0 |9 |9.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|4.0 |10 |10.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|3.0 |11 |11.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|4.0 |12 |12.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|3.0 |13 |13.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|4.0 |14 |14.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a |silom|5.0 |15 |15.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
Что data_to_update
говорит нам, что rn
[4.0, 5.0, 6.0]
необходимо обновить, изменив id
на 0.0
, а rn
[10.0, 11.0, 12.0, 12.0]
нужно обновить, изменив id
на 4.0
.
Моя попытка
Я думаю, что могу использовать UDF для обработки столбца data_to_update
, используя withColumn
для обновления столбца id
. Тем не менее, я сложная структура данных является то, что вызывает проблемы для меня. Пока у меня есть
// I will call this in my UDF
def checkArray(clusterID: Double, colRN: Double, arr: Array[(Array[Double], Double)]): Double = {
// @tailrec
def getReturnID(clusterID: Double, colRN: Double, arr: Array[(Array[Double], Double)]): Double = arr match {
case arr if arr.nonEmpty && arr(0)._1.contains(colRN) =>
arr(0)._2
case arr if arr.nonEmpty && !arr(0)._1.contains(colRN) =>
getReturnID(clusterID, colRN, arr.drop(1))
case _ => clusterID
}
getReturnID(clusterID, colRN, arr)
}
val columnUpdate: UserDefinedFunction = udf {
(colID: Double, colRN: Double, colArrayData: Array[(Array[Double], Double)]) =>
if(colArrayData.length > 0) {
checkArray(colID, colRN, colArrayData)
}
else {
colID
}
}
data
.join(broadcast(identifiedData), Seq("user", "cat"), "inner")
.withColumn("id", columnUpdate($"id", $"rn", $"data_to_update"))
.show(100, false)
Я не могу получить доступ к кортежу, который был преобразован в структуру.
org.apache.spark.SparkException: Failed to execute user defined function($anonfun$1: (double, double, array<struct<_1:array<double>,_2:double>>) => double)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage7.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:109)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
Caused by: java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to [Lscala.Tuple2;
at $line329.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:37)
Из чтения Spark UDF для StructType / Row , я имитировал их настройку, но все еще безуспешно. Изменения, которые я сделал, были
def checkArray(clusterID: Double,
colRN: Double,
dataStruct: Row): Double = {
// map back array to fill id value
val arrData = dataStruct
.getAs[Seq[Double]](0)
.zipWithIndex.map{
case (arr, idx) => (Array(arr), dataStruct.getAs[Seq[Double]](1)(idx))
}
// @tailrec
def getReturnID(clusterID: Double,
colRN: Double,
arr: Seq[(Array[Double], Double)]): Double = arr match {
case arr if arr.nonEmpty && arr(0)._1.contains(colRN) =>
arr(0)._2
case arr if arr.nonEmpty && !arr(0)._1.contains(colRN) =>
getReturnID(clusterID, colRN, arr.drop(1))
case _ => clusterID
}
getReturnID(clusterID, colRN, arrData)
}
val columnUpdate: UserDefinedFunction = udf {
(colID: Double, colRN: Double, colStructData: Row) =>
if(colStructData.getAs[Seq[Double]](0).nonEmpty) {
checkArray(colID, colRN, colStructData)
}
else {
colID
}
}
Я передаю Row
. Данные имеют вид
data.join(
broadcast(identifiedData
.withColumn("data_to_update", $"data_to_update")
.withColumn("array_data", $"data_to_update._1")
.withColumn("value", $"data_to_update._2")
.withColumn("data_struct", struct("array_data", "value"))
.drop("data_to_update", "array_data", "value")
),
Seq("user", "cat"),
"inner"
)
со следующей схемой
root
|-- user: string (nullable = true)
|-- cat: string (nullable = true)
|-- id: double (nullable = false)
|-- time_sec: integer (nullable = false)
|-- rn: double (nullable = true)
|-- data_struct: struct (nullable = false)
| |-- array_data: array (nullable = true)
| | |-- element: array (containsNull = true)
| | | |-- element: double (containsNull = false)
| |-- value: array (nullable = true)
| | |-- element: double (containsNull = true)
data.join(
broadcast(identifiedData
.withColumn("data_to_update", $"data_to_update")
.withColumn("array", $"data_to_update._1")
.withColumn("value", $"data_to_update._2")
.withColumn("data_struct", struct("array", "value"))
.drop("data_to_update", "array", "value")
),
Seq("user", "cat"),
"inner"
)
.withColumn("id", columnUpdate($"id", $"rn", $"data_struct"))
.show
При таком подходе я получаю следующую ошибку:
Caused by: java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to java.lang.Double
Однако я Я передаю Row
, а затем выполняю getAs
для преобразования в необходимую структуру данных.
Как построить данные
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.types.DoubleType
import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.expressions.UserDefinedFunction
import scala.util.Sorting.stableSort
import org.apache.spark.sql.Row
val dataDF = Seq(
("a", "silom", 3, 1),
("a", "silom", 2, 2),
("a", "silom", 1, 3),
("a", "silom", 0, 4),
("a", "silom", 1, 5),
("a", "silom", 0, 6),
("a", "silom", 1, 7),
("a", "silom", 2, 8),
("a", "silom", 3, 9),
("a", "silom", 4, 10),
("a", "silom", 3, 11),
("a", "silom", 4, 12),
("a", "silom", 3, 13),
("a", "silom", 4, 14),
("a", "silom", 5, 15),
("a", "suk", 18, 1),
("a", "suk", 19, 2),
("a", "suk", 20, 3),
("a", "suk", 21, 4),
("a", "suk", 20, 5),
("a", "suk", 21, 6),
("a", "suk", 0, 7),
("a", "suk", 1, 8),
("a", "suk", 2, 9),
("a", "suk", 3, 10),
("a", "suk", 4, 11),
("a", "suk", 3, 12),
("a", "suk", 4, 13),
("a", "suk", 3, 14),
("a", "suk", 5, 15),
("b", "silom", 4, 1),
("b", "silom", 3, 2),
("b", "silom", 2, 3),
("b", "silom", 1, 4),
("b", "silom", 0, 5),
("b", "silom", 1, 6),
("b", "silom", 2, 7),
("b", "silom", 3, 8),
("b", "silom", 4, 9),
("b", "silom", 5, 10),
("b", "silom", 6, 11),
("b", "silom", 7, 12),
("b", "silom", 8, 13),
("b", "silom", 9, 14),
("b", "silom", 10, 15),
("b", "suk", 11, 1),
("b", "suk", 12, 2),
("b", "suk", 13, 3),
("b", "suk", 14, 4),
("b", "suk", 13, 5),
("b", "suk", 14, 6),
("b", "suk", 13, 7),
("b", "suk", 12, 8),
("b", "suk", 11, 9),
("b", "suk", 10, 10),
("b", "suk", 9, 11),
("b", "suk", 8, 12),
("b", "suk", 7, 13),
("b", "suk", 6, 14),
("b", "suk", 5, 15)
).toDF("user", "cat", "id", "time_sec")
val recastDataDF = dataDF.withColumn("id", $"id".cast(DoubleType))
val category = recastDataDF.select("cat").distinct.collect.map(x => x(0).toString)
val data = recastDataDF
.withColumn("rn", row_number.over(Window.partitionBy("user", "cat").orderBy("time_sec")).cast(DoubleType))
val data2 = recastDataDF
.select($"*" +: category.map(
name =>
lag("id", 1).over(
Window.partitionBy("user", "cat").orderBy("time_sec")
)
.alias(s"lag_${name}_id")): _*)
.withColumn("sequencing_diff", when($"cat" === "silom", ($"lag_silom_id" - $"id").cast(DoubleType))
.otherwise(($"lag_suk_id" - $"id")))
.drop("lag_silom_id", "lag_suk_id")
.withColumn("rn", row_number.over(Window.partitionBy("user", "cat").orderBy("time_sec")).cast(DoubleType))
.withColumn("id_rn", array($"id", $"rn", $"sequencing_diff"))
.groupBy($"user", $"cat").agg(collect_list($"id_rn").alias("array_data"))
def getGrps2(arr: Array[(Double, Double)]): Array[(Array[Double], Double)] = {
// @tailrec
def returnAlternatingIDs(arr: Array[(Double, Double)],
altIDs: Array[(Array[Double], Double)]): Array[(Array[Double], Double)] = arr match {
case arr if arr.nonEmpty =>
val rowNum = arr.take(1)(0)._1
val keepID = arr.take(1)(0)._2
val newArr = arr.drop(1)
val rowNums = (Array(rowNum)) ++ newArr.zipWithIndex.map{
case (tups, idx) =>
if(rowNum + idx + 1 == tups._1) {
rowNum + 1 + idx
}
else {
Double.NaN
}
}
.filter(v => v == v)
val updateArray = altIDs ++ Array((rowNums, keepID))
returnAlternatingIDs(arr.drop(rowNums.length), updateArray)
case _ => altIDs
}
returnAlternatingIDs(arr, Array((Array(Double.NaN), Double.NaN))).drop(1)
}
val identifyFlickeringIDs: UserDefinedFunction = udf {
(colArrayData: WrappedArray[WrappedArray[Double]]) =>
val newArray: Array[(Double, Double, Double)] = colArrayData.toArray
.map(x => (x(0).toDouble, x(1).toDouble, x(2).toDouble))
// sort array by rn via less than relation
stableSort(newArray, (e1: (Double, Double, Double), e2: (Double, Double, Double)) => e1._2 < e2._2)
val shifted: Array[(Double, Double, Double)] = newArray.toArray.drop(1)
val combined = newArray
.zipAll(shifted, (Double.NaN, Double.NaN, Double.NaN), (Double.NaN, Double.NaN, Double.NaN))
val parsedArray = combined.map{
case (data0, data1) =>
if(data0._3 != data1._3 && data0._2 > 1 && data0._3 + data1._3 == 0) {
(data0._2, data0._1)
}
else (Double.NaN, Double.NaN)
}
.filter(t => t._1 == t._1 && t._1 == t._1)
getGrps2(parsedArray).filter(data => data._1.length > 1)
}
val identifiedData = data2.withColumn("data_to_update", identifyFlickeringIDs($"array_data"))
.drop("array_data")