Ваша ошибка вызвана передачей 3
и missing_fill
как литералов python в clip
. Как описано в в этом ответе , входные данные для udf
преобразуются в столбцы.
Вместо этого вы должны передавать литералы столбцов.
Вот упрощенный пример DataFrame:
example.show(truncate=False)
#+---+------------------------+--------------------+---------------+--------------------+
#|id |timestamp |thing1 |thing2 |thing3 |
#+---+------------------------+--------------------+---------------+--------------------+
#|A |[time_0, time_1, time_2]|[1.2, 1.1, 2.2] |[1.3, 1.5, 2.6]|[2.5, 3.4, 2.9] |
#|B |[time_0, time_1] |[5.1, 6.1, 1.4, 1.6]|[5.5, 6.2, 0.2]|[5.7, 6.3] |
#|C |[time_0, time_1] |[0.1, 0.2, 1.1] |[0.5, 0.3, 0.3]|[0.9, 0.6, 0.9, 0.4]|
#+---+------------------------+--------------------+---------------+--------------------+
Вам просто нужно внести одно небольшое изменение в аргументы, передаваемые udf
:
from pyspark.sql.functions import lit, udf
def clip_func(x, ts_len, backfill):
template = [backfill]*ts_len
template[-len(x):] = x
x = template
return x[-1 * ts_len:]
clip = udf(clip_func, ArrayType(DoubleType()))
ans = example
for c in [x for x in example.columns if 'thing' in x]:
missing_fill = 3.3
ans = ans.withColumn(c, clip(c, lit(3), lit(missing_fill)))
ans.show(truncate=False)
#+---+------------------------+---------------+---------------+---------------+
#|id |timestamp |thing1 |thing2 |thing3 |
#+---+------------------------+---------------+---------------+---------------+
#|A |[time_0, time_1, time_2]|[1.2, 1.1, 2.2]|[1.3, 1.5, 2.6]|[2.5, 3.4, 2.9]|
#|B |[time_0, time_1] |[6.1, 1.4, 1.6]|[5.5, 6.2, 0.2]|[3.3, 5.7, 6.3]|
#|C |[time_0, time_1] |[0.1, 0.2, 1.1]|[0.5, 0.3, 0.3]|[0.6, 0.9, 0.4]|
#+---+------------------------+---------------+---------------+---------------+
Как написано udf
:
- Если массив длиннее
ts_len
, он будет обрезать массив с начала (слева).
- Если массив короче
ts_len
, он добавит missing_fill
в начало массива.