Проблема Pandas_udf: применить функцию к каждой строке, где данные ArrayType - PullRequest
0 голосов
/ 08 апреля 2019

У меня есть фрейм данных pyspark, в котором несколько столбцов содержат массивы различной длины.Я хочу перебрать соответствующие столбцы и обрезать массивы в каждой строке, чтобы они были одинаковой длины.В этом примере длина 3.

Это пример кадра данных:

id_1|id_2|id_3|        timestamp     |thing1       |thing2       |thing3
A   |b  |  c |[time_0,time_1,time_2]|[1.2,1.1,2.2]|[1.3,1.5,2.6|[2.5,3.4,2.9]
A   |b  |  d |[time_0,time_1]       |[5.1,6.1, 1.4, 1.6]    |[5.5,6.2, 0.2]   |[5.7,6.3]
A   |b  |  e |[time_0,time_1]       |[0.1,0.2, 1.1]    |[0.5,0.3, 0.3]   |[0.9,0.6, 0.9, 0.4]

Пока у меня есть,

 def clip_func(x, ts_len, backfill=1500):
     template = [backfill]*ts_len
     template[-len(x):] = x
     x = template
     return x[-1 * ts_len:]

clip = udf(clip_func, ArrayType(DoubleType()))

for c in [x for x in example.columns if 'thing' in x]:
    missing_fill = 3.3
    ans = ans.withColumn(c, clip(c, 3, missing_fill))

Но не работает.Если массив слишком короткий, я хочу заполнить массив значением missing_fill.

1 Ответ

1 голос
/ 08 апреля 2019

Ваша ошибка вызвана передачей 3 и missing_fill как литералов python в clip. Как описано в в этом ответе , входные данные для udf преобразуются в столбцы.

Вместо этого вы должны передавать литералы столбцов.

Вот упрощенный пример DataFrame:

example.show(truncate=False)
#+---+------------------------+--------------------+---------------+--------------------+
#|id |timestamp               |thing1              |thing2         |thing3              |
#+---+------------------------+--------------------+---------------+--------------------+
#|A  |[time_0, time_1, time_2]|[1.2, 1.1, 2.2]     |[1.3, 1.5, 2.6]|[2.5, 3.4, 2.9]     |
#|B  |[time_0, time_1]        |[5.1, 6.1, 1.4, 1.6]|[5.5, 6.2, 0.2]|[5.7, 6.3]          |
#|C  |[time_0, time_1]        |[0.1, 0.2, 1.1]     |[0.5, 0.3, 0.3]|[0.9, 0.6, 0.9, 0.4]|
#+---+------------------------+--------------------+---------------+--------------------+

Вам просто нужно внести одно небольшое изменение в аргументы, передаваемые udf:

from pyspark.sql.functions import lit, udf

def clip_func(x, ts_len, backfill):
    template = [backfill]*ts_len
    template[-len(x):] = x
    x = template
    return x[-1 * ts_len:]

clip = udf(clip_func, ArrayType(DoubleType()))

ans = example
for c in [x for x in example.columns if 'thing' in x]:
    missing_fill = 3.3
    ans = ans.withColumn(c, clip(c, lit(3), lit(missing_fill)))

ans.show(truncate=False)
#+---+------------------------+---------------+---------------+---------------+
#|id |timestamp               |thing1         |thing2         |thing3         |
#+---+------------------------+---------------+---------------+---------------+
#|A  |[time_0, time_1, time_2]|[1.2, 1.1, 2.2]|[1.3, 1.5, 2.6]|[2.5, 3.4, 2.9]|
#|B  |[time_0, time_1]        |[6.1, 1.4, 1.6]|[5.5, 6.2, 0.2]|[3.3, 5.7, 6.3]|
#|C  |[time_0, time_1]        |[0.1, 0.2, 1.1]|[0.5, 0.3, 0.3]|[0.6, 0.9, 0.4]|
#+---+------------------------+---------------+---------------+---------------+

Как написано udf:

  • Если массив длиннее ts_len, он будет обрезать массив с начала (слева).
  • Если массив короче ts_len, он добавит missing_fill в начало массива.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...