Лучший способ транспонировать или поворачивать массив категориальных переменных для кодировки Spark ML - PullRequest
0 голосов
/ 25 января 2019

Я работаю над настройкой категориальных переменных для моделей с искровым ML. Вместо столбца с отдельными категориальными переменными у меня есть столбец с массивом категориальных переменных. См. Пример данных ниже.

(хотя это числа, они представляют категорию).

Мне нужно выделить их в отдельные функции, например, важно сохранить, чтобы # 1, # 3, # 6 и # 7 имели категорию 19, независимо от того, какие другие категории находятся в массиве.

Я мог бы использовать SQL, чтобы вручную определить все категориальные переменные и создать столбец для каждой из них. Но это не кажется элегантным, я думаю, что должен быть лучший способ, чтобы все категории были повернуты к столбцам, а затем обозначены как 1 или 0, которые затем могут быть закодированы в горячем виде. Или мне интересно, есть ли лучший способ подумать о проблеме?

Я использую scala 2.2.0 (и сейчас не могу обновить), поэтому не могу использовать более новые функции массива.

+---------------+----------------+
|id             |categorical_code|
+---------------+----------------+
|1              |           [19] |
|2              |       [87, 19] |
|3              |           [18] |
|4              |           [96] |
|5              |           [18] |
|6              |  [111, 22, 19] |
|7              |  [161, 19, 18] |
|8              |           [12] |
|9              |          [170] |
+---------------+----------------+

Требуется вывод (я думаю) что-то вроде:

id,cat_12,cat_18,cat_19,cat_22,cat_87,cat_111,cat_161,cat_170
1,,,1,,,,,
2,,,1,,1,,,
3,,1,,,,,,
4,,,,,,,,
5,,1,,,,,,
6,,,1,1,,1,1,
7,,1,1,,,,,
8,1,,,,,,,1
9,,,,,,,,

1 Ответ

0 голосов
/ 28 января 2019

Мы можем разбить массив на отдельные строки, а затем использовать groupby-pivot для получения требуемого вывода.

val df2 =
  df.
    select(
      df("id"),
      explode(df("categorical_code")).as("categorical_code"),
      lit(1).as("categorical_code_exist")
    )

df2.show()
+---+----------------+----------------------+
| id|categorical_code|categorical_code_exist|
+---+----------------+----------------------+
|  1|              19|                     1|
|  2|              87|                     1|
|  2|              19|                     1|
|  3|              18|                     1|
|  4|              96|                     1|
|  5|              18|                     1|
|  6|             111|                     1|
|  6|              22|                     1|
|  6|              19|                     1|
|  7|             161|                     1|
|  7|              19|                     1|
|  7|              18|                     1|
|  8|              12|                     1|
|  9|             170|                     1|
+---+----------------+----------------------+

val df3 =
  df2.
    groupBy("id").
    pivot("categorical_code").
    agg(coalesce(first(df2("categorical_code_exist")))).
    orderBy("id")

df3.show()
+---+----+----+----+----+----+----+----+----+----+
| id|  12|  18|  19|  22|  87|  96| 111| 161| 170|
+---+----+----+----+----+----+----+----+----+----+
|  1|null|null|   1|null|null|null|null|null|null|
|  2|null|null|   1|null|   1|null|null|null|null|
|  3|null|   1|null|null|null|null|null|null|null|
|  4|null|null|null|null|null|   1|null|null|null|
|  5|null|   1|null|null|null|null|null|null|null|
|  6|null|null|   1|   1|null|null|   1|null|null|
|  7|null|   1|   1|null|null|null|null|   1|null|
|  8|   1|null|null|null|null|null|null|null|null|
|  9|null|null|null|null|null|null|null|null|   1|
+---+----+----+----+----+----+----+----+----+----+

df3.printSchema()
root
 |-- id: integer (nullable = true)
 |-- 12: integer (nullable = true)
 |-- 18: integer (nullable = true)
 |-- 19: integer (nullable = true)
 |-- 22: integer (nullable = true)
 |-- 87: integer (nullable = true)
 |-- 96: integer (nullable = true)
 |-- 111: integer (nullable = true)
 |-- 161: integer (nullable = true)
 |-- 170: integer (nullable = true)

...