Позвольте мне объяснить на примере, чего я хочу достичь.
Начиная с DataFrame следующим образом:
val df = Seq((1, "CS", 0, (0.1, 0.2, 0.4, 0.5)),
(4, "Ed", 0, (0.4, 0.8, 0.3, 0.6)),
(7, "CS", 0, (0.2, 0.5, 0.4, 0.7)),
(101, "CS", 1, (0.5, 0.7, 0.3, 0.8)),
(5, "CS", 1, (0.4, 0.2, 0.6, 0.9)))
.toDF("id", "dept", "test", "array")
+---+----+----+--------------------+
| id|dept|test| array|
+---+----+----+--------------------+
| 1| CS| 0|[0.1, 0.2, 0.4, 0.5]|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]|
| 7| CS| 0|[0.2, 0.5, 0.4, 0.7]|
|101| CS| 1|[0.5, 0.7, 0.3, 0.8]|
| 5| CS| 1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+
Я хочу изменить некоторые элементы столбца массива в соответствии с информацией в столбце id, dept и test. Сначала я добавляю индекс в каждую строку для разных отделов следующим образом:
@transient val w = Window.partitionBy("dept").orderBy("id")
val tempdf = df.withColumn("Index", row_number().over(w))
tempdf.show
+---+----+----+--------------------+-----+
| id|dept|test| array|Index|
+---+----+----+--------------------+-----+
| 1| CS| 0|[0.1, 0.2, 0.4, 0.5]| 1|
| 5| CS| 1|[0.4, 0.2, 0.6, 0.9]| 2|
| 7| CS| 0|[0.2, 0.5, 0.4, 0.7]| 3|
|101| CS| 1|[0.5, 0.7, 0.3, 0.8]| 4|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]| 1|
+---+----+----+--------------------+-----+
Чего я хочу добиться, так это минус константу (0,1) от одного элемента в столбце массива, причем его местоположение соответствует индексу строки в каждом отделе. Например, в случае "dept == CS" конечный результат должен быть:
+---+----+----+--------------------+-----+
| id|dept|test| array|Index|
+---+----+----+--------------------+-----+
| 1| CS| 0|[0.0, 0.2, 0.4, 0.5]| 1|
| 5| CS| 1|[0.4, 0.1, 0.6, 0.9]| 2|
| 7| CS| 0|[0.2, 0.5, 0.3, 0.7]| 3|
|101| CS| 1|[0.5, 0.7, 0.3, 0.7]| 4|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]| 1|
+---+----+----+--------------------+-----+
В настоящее время я думаю о достижении этого с помощью udf следующим образом:
def subUdf = udf((array: Seq[Double], dampFactor: Double, additionalIndex: Int) => additionalIndex match{
case 0 => array
case _ => { val temp = array.zipWithIndex
var mask = Array.fill(array.length)(0.0)
mask(additionalIndex-1) = dampFactor
val tempAdj = temp.map(x => if (additionalIndex == (x._2+1)) (x._1-mask, x._2) else x)
tempAdj.map(_._1)
}
}
)
val dampFactor = 0.1
val finaldf = tempdf.withColumn("array", subUdf(tempdf("array"), dampFactor, when(tempdf("dept") === "CS" && tempdf("test") === 0, tempdf("Index")).otherwise(lit(0)))).drop("Index")
В файле udf есть ошибка компиляции из-за перегрузки:
Name: Compile Error
Message: <console>:34: error: overloaded method value - with alternatives:
(x: Double)Double <and>
(x: Float)Double <and>
(x: Long)Double <and>
(x: Int)Double <and>
(x: Char)Double <and>
(x: Short)Double <and>
(x: Byte)Double
cannot be applied to (Array[Double])
val tempAdj = temp.map(x => if (additionalIndex == (x._2+1)) (x._1-mask, x._2) else x)
^
Два связанных вопроса:
Как устранить ошибку компиляции?
Я открыт для использования метода, отличного от udf, для достижения этой цели.