Как рассчитать количество дублированных элементов во вложенном списке в PySpark? - PullRequest
2 голосов
/ 07 января 2020

У меня есть следующий DataFrame в PySpark:

+----------+------------------------+
|        id|              codes_list|
+----------+------------------------+
|      FF10|   [[1049, 1683], [108]]|
|      AB36|        [[1507], [1005]]|
|      8266|[[1049], [1049], [1049]]|
+----------+------------------------+

Это схема:

root
 |-- id: string (nullable = true)
 |-- codes_list: array (nullable = true)
 |    |-- element: string (containsNull = true)

Как рассчитать количество дублированных кодов в codes_list?

Это ожидаемый результат:

+----------+----+
|        id| qty|
+----------+----+
|      FF10|   0|
|      AB36|   0|
|      8266|   1|
+----------+----+

1 Ответ

0 голосов
/ 08 января 2020

Один простой способ - взорвать эти числа и считать каждое вхождение id, code. Затем сгруппируйте по id и используйте условную сумму, чтобы получить количество дублированных значений.

Поскольку массив содержит строки, а не подмассивы, вы можете, во-первых, удалить квадратные скобки и разделить на , чтобы получить коды.

data = [("FF10", ["[1049, 1683]", "[108]"]),
        ("FAB36", ["[1507]", "[1005]"]),
        ("8266", ["[1049]", "[1049]", "[1049]"])]

df = spark.createDataFrame(data, ["id", "codes_list"])


df.withColumn("codes", explode("codes_list")) \
  .withColumn("codes", explode(split(regexp_replace("codes", "[\\[\\]]", ""), ","))) \
  .groupBy("id", "codes").count() \
  .groupBy("id").agg(sum((col("count") > lit(1)).cast("int")).alias("qty")) \
  .show()

Дает:

+-----+---+
|   id|qty|
+-----+---+
| FF10|  0|
| 8266|  1|
|FAB36|  0|
+-----+---+
...