Как использовать скалярный pandas_udf в pyspark для столбцов типа массива - PullRequest
0 голосов
/ 22 марта 2019

Я пытаюсь реализовать скалярный pandas_udf в spark, но получаю ошибки при выполнении определенной операции. Ниже приведены подробности о структуре столбцов и udf, которые я написал:

dataframe schema for array type column:

list_col1: array (nullable = true)
 |    |-- element: string (containsNull = true)

from pyspark.sql import functions as F
from pyspark.sql.functions import udf, flatten, pandas_udf
from pyspark.sql.types import ArrayType, StringType, TimestampType
from pyspark.sql import Row


@pandas_udf(ArrayType(StringType()), PandasUDFType.SCALAR)
def truncate_data_udf(list_type_col, output_list_length):     
    sortedList=pd.Series(list_type_col).tolist()    
    un_list=list(OrderedDict.fromkeys(sortedList))
    trunc_size=int(output_list_length)   
    if len(un_list)>trunc_size:
        un_list=un_list[:trunc_size]
        un_list.insert(0, 'truncated')

    return pd.Series(un_list)

df = df.withColumn("list_col", truncate_data_udf(flatten(F.col("list_col1")), lit(10)))

Expected result is truncated list having elements equal to 10.

Так, в каком формате или тип данных входные данные передаются в pandas_udf. Если я хочу преобразовать данные входного столбца в список, то как я могу это сделать. И при возврате набора данных, как я могу вернуть результат в виде списка.

The result column should have schema like:
list_col1: array (nullable = true)
 |    |-- element: string (containsNull = true)

Я также написал обычный udf, как показано ниже, который работает, как и ожидалось. Но я хочу определить разницу в показателях между обычным и pandas_udf. Я считаю, что pandas_udf намного быстрее, чем обычный udf.

Normal udf:

def truncate_data(list_type_col, output_list_length): 
    l= list(OrderedDict.fromkeys(list_type_col))
    if l is not None and len(l) > output_list_length:
        l = l[:output_list_length]        
        l.insert(0, 'truncated')    
    return(l)

truncate_data_udf= udf(lambda row: truncate_data(row, output_list_length), ArrayType(StringType()))

df = df.withColumn("list_col", truncate_data_udf(flatten(F.col("list_col1"))))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...