Почему тип данных изменился при вызове UDF в scala - PullRequest
0 голосов
/ 13 марта 2019

У меня есть df:

joined.printSchema
root
 |-- cc_num: long (nullable = true)
 |-- lat: double (nullable = true)
 |-- long: double (nullable = true)
 |-- merch_lat: double (nullable = true)
 |-- merch_long: double (nullable = true)

У меня есть udf:

def getDistance (lat1:Double, lon1:Double, lat2:Double, lon2:Double) = {
    val r : Int = 6371 //Earth radius
    val latDistance : Double = Math.toRadians(lat2 - lat1)
    val lonDistance : Double = Math.toRadians(lon2 - lon1)
    val a : Double = Math.sin(latDistance / 2) * Math.sin(latDistance / 2) + Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) * Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2)
    val c : Double = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
    val distance : Double = r * c
    distance
  }

Мне нужно создать новый столбец для df с помощью:

joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))

Я получил ошибку ниже:

Name: Unknown Error
Message: <console>:35: error: type mismatch;
 found   : String("lat")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                          ^
<console>:35: error: type mismatch;
 found   : String("long")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                 ^
<console>:35: error: type mismatch;
 found   : String("merch_lat")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                         ^
<console>:35: error: type mismatch;
 found   : String("merch_long")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                                      ^

Как видно из схемы, все задействованные поля имеют тип double, который соответствует определению типа параметра udf, поэтому я вижу несоответствие типа данныхошибка?

Может кто-нибудь просветить здесь, что не так и как это исправить?

Большое спасибо.

1 Ответ

2 голосов
/ 13 марта 2019

Ваш метод getDistance НЕ является UDF, это метод Scala, ожидающий 4 Double аргументов, и вместо этого вы передаете 4 строки.

Чтобы это исправить, вам необходимо:

  • "Оберните" ваш метод с помощью UDF и
  • Передайте column аргументы, а не Strings при применении UDF, что можно сделать, добавив к имени столбца префикс $
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import spark.implicits._ // assuming "spark" is your SparkSession

val distanceUdf: UserDefinedFunction = udf(getDistance _)

joined.withColumn("distance", distanceUdf($"lat", $"long", $"merch_lat", $"merch_long"))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...