Передать столбец и карту в Scala UDF - PullRequest
2 голосов
/ 04 марта 2020

Я из Писпарка. Я знаю, как это сделать в Pyspark, но мне не удалось сделать то же самое в Scala.

Вот кадр данных,

val df = Seq(
  ("u1", Array[Int](2,3,4)),
  ("u2", Array[Int](7,8,9))
).toDF("id", "mylist")


//    +---+---------+
//    | id|   mylist|
//    +---+---------+
//    | u1|[2, 3, 4]|
//    | u2|[7, 8, 9]|
//    +---+---------+

, а вот объект Map,

val myMap = (1 to 4).toList.map(x => (x,0)).toMap

//myMap: scala.collection.immutable.Map[Int,Int] = Map(1 -> 0, 2 -> 0, 3 -> 0, 4 -> 0)

, поэтому эта карта имеет ключевые значения от 1 до 4.

Для каждой строки df я хочу проверить, содержится ли какой-либо элемент в «mylist» в myMap в качестве ключа ценность. Если myMap содержит элемент, вернуть этот элемент (вернуть любой, если содержится несколько элементов), в противном случае вернуть -1.

Таким образом, результат должен выглядеть следующим образом:

    +---+---------+-------+
    | id|   mylist| label|
    +---+---------+-------+
    | u1|[2, 3, 4]|    2  |
    | u2|[7, 8, 9]|    -1 |
    +---+---------+-------+

Я пробовал следующие подходы:

  1. ниже функция работает для объекта массива, но не работает для столбца:
def list2label(ls: Array[Int],
                m:  Map[Int, Int]):(Int) = {
                    var flag = 0
                    for (element <- ls) {
                        if (m.contains(element)) flag = element
                    }
                    flag
                }

val testls = Array[Int](2,3,4)
list2label(testls, myMap)

//testls: Array[Int] = Array(2, 3, 4)
//res33: Int = 4
пытался использовать UDF, но получил ошибку:
def list2label_udf(m: Map[Int, Int]) = udf( (ls: Array[Int]) =>(

                    var flag = 0
                    for (element <- ls) {
                        if (m.contains(element)) flag = element
                    }
                    flag
    )
)

//<console>:3: error: illegal start of simple expression
//                    var flag = 0
//                    ^

Я думаю, что мой udf в неправильном формате ..

в Pyspark Я могу сделать это, как я буду sh:
%pyspark

myDict={1:0, 2:0, 3:0, 4:0}

def list2label(ls, myDict):
    for i in ls:
        if i in dict3:
            return i
    return 0

def list2label_UDF(myDict):
     return udf(lambda c: list2label(c,myDict))

df = df.withColumn("label",list2label_UDF(myDict)(col("mylist")))

Любая помощь будет оценена!

1 Ответ

1 голос
/ 04 марта 2020

Решение показано ниже:

  scala> df.show
+---+---------+
| id|   mylist|
+---+---------+
| u1|[2, 3, 4]|
| u2|[7, 8, 9]|
+---+---------+


scala> def customUdf(m: Map[Int,Int]) = udf((s: Seq[Int]) => {
          val intersection = s.toList.intersect(m.keys.toList)
          if(intersection.isEmpty) -1 else intersection(0)})

customUdf: (m: Map[Int,Int])org.apache.spark.sql.expressions.UserDefinedFunction

scala> df.select($"id", $"myList", customUdf(myMap)($"myList").as("new_col")).show
+---+---------+-------+
| id|   myList|new_col|
+---+---------+-------+
| u1|[2, 3, 4]|      2|
| u2|[7, 8, 9]|     -1|
+---+---------+-------+

Другим подходом может быть отправка списка ключей карты вместо самой карты, поскольку ypu проверяет только ключи. Для этого решение приведено ниже:

scala> def customUdf1(m: List[Int]) = udf((s: Seq[Int]) => {
          val intersection = s.toList.intersect(m)
          if(intersection.isEmpty) -1 else intersection(0)})

customUdf1: (m: List[Int])org.apache.spark.sql.expressions.UserDefinedFunction

scala> df.select($"id",$"myList", customUdf1(myMap.keys.toList)($"myList").as("new_col")).show
+---+---------+-------+
| id|   myList|new_col|
+---+---------+-------+
| u1|[2, 3, 4]|      2|
| u2|[7, 8, 9]|     -1|
+---+---------+-------+

Дайте мне знать, если это поможет !!

...