Фильтрация элементов массива данных на основе внешнего массива с пересечением - PullRequest
3 голосов
/ 14 мая 2019

Я пытаюсь определить способ фильтрации элементов из WrappedArrays в DF. Фильтр основан на внешнем списке элементов.

В поисках решения я нашел этот вопрос . Это очень похоже, но, похоже, не работает для меня. Я использую Spark 2.4.0. Это мой код:

val df = sc.parallelize(Array((1, Seq("s", "v", "r")),(2, Seq("r", "a", "v")),(3, Seq("s", "r", "t")))).toDF("foo","bar")


def filterItems(flist: Seq[String]) = udf {
  (recs: Seq[String]) => recs match {
    case null => Seq.empty[String]
    case recs => recs.intersect(flist)
  }}

df.withColumn("filtercol", filterItems(Seq("s", "v"))(col("bar"))).show(5)

Мой ожидаемый результат будет:

+---+---------+---------+ 
|foo| bar|filtercol| 
+---+---------+---------+ 
| 1 |[s, v, r]|   [s, v]| 
| 2 |[r, a, v]|      [v]| 
| 3| [s, r, t]|      [s]| 
+---+---------+---------+

Но я получаю эту ошибку:

java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD

Ответы [ 2 ]

3 голосов
/ 14 мая 2019

Вы можете использовать встроенную функцию в Spark 2.4 без особых усилий:

import org.apache.spark.sql.functions.{array_intersect, array, lit}

val df = sc.parallelize(Array((1, Seq("s", "v", "r")),(2, Seq("r", "a", "v")),(3, Seq("s", "r", "t")))).toDF("foo","bar")

val ar = Seq("s", "v").map(lit(_))
df.withColumn("filtercol", array_intersect($"bar", array(ar:_*))).show

Выход:

+---+---------+---------+
|foo|      bar|filtercol|
+---+---------+---------+
|  1|[s, v, r]|   [s, v]|
|  2|[r, a, v]|      [v]|
|  3|[s, r, t]|      [s]|
+---+---------+---------+

Единственная сложная часть - Seq("s", "v").map(lit(_)), которая отображает каждую строку в lit(i). Функция intersection принимает два массива. Первым является значение столбца bar. Второй создает его на лету с array(ar:_*), который будет содержать значения lit(i).

0 голосов
/ 14 мая 2019

Если вы передаете атрибут ArrayType в UDF, он появляется как экземпляр WrappedArray, который не a List. Поэтому вы должны изменить тип recs на Seq, IndexedSeq или WrappedArray, обычно я просто использую обычный Seq:

def filterItems(flist: List[String]) = udf {
  (recs: Seq[String]) => recs match {
    case null => Seq.empty[String]
    case recs => recs.intersect(flist)
  }}
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...