Your main problem is the UDF's return type and how you access the elements inside the column; defining struct1 correctly is the key. Here is how to solve it.
from pyspark.sql.types import ArrayType, StructField, StructType, DoubleType, StringType
from pyspark.sql import functions as F
# struct1 describes one element of the urlB array; struct2 is the schema of the whole row
struct1 = StructType([StructField("distCol", DoubleType(), True), StructField("url", StringType(), True)])
struct2 = StructType([StructField("urlA", StringType(), True), StructField("urlB", ArrayType(struct1), True)])
# Create DataFrame
df = spark.createDataFrame([
    ['url_a1', [[0.03, 'url1'], [0.02, 'url2'], [0.01, 'url3']]],
    ['url_a2', [[0.05, 'url4'], [0.03, 'url5']]]
], struct2)
Input:
+------+------------------------------------------+
|urlA |urlB |
+------+------------------------------------------+
|url_a1|[[0.03, url1], [0.02, url2], [0.01, url3]]|
|url_a2|[[0.05, url4], [0.03, url5]] |
+------+------------------------------------------+
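For reference, printSchema() confirms that urlB is an array of structs, which is why each element can be indexed by field name inside the UDF (the commented output below is what this schema produces):
df.printSchema()
# root
#  |-- urlA: string (nullable = true)
#  |-- urlB: array (nullable = true)
#  |    |-- element: struct (containsNull = true)
#  |    |    |-- distCol: double (nullable = true)
#  |    |    |-- url: string (nullable = true)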
UDF:
# Define udf: sort the array by distCol ascending and keep the top_N closest urls
top_N = 5
def rank_url(array):
    ranked_url = sorted(array, key=lambda x: x['distCol'])[0:top_N]
    return ranked_url

# The return type must be ArrayType(struct1) so the struct fields stay accessible downstream
url_udf = F.udf(rank_url, ArrayType(struct1))
# Apply udf
df2 = df.select('urlA', url_udf('urlB'))
df2.show(truncate=False)
Output:
+------+------------------------------------------+
|urlA |rank_url(urlB) |
+------+------------------------------------------+
|url_a1|[[0.01, url3], [0.02, url2], [0.03, url1]]|
|url_a2|[[0.03, url5], [0.05, url4]] |
+------+------------------------------------------+
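As a side note, if you are on Spark 2.4+ you can get the same ranking without a Python UDF, which avoids shipping the rows out to Python. This is only a minimal sketch, assuming the column names above and relying on the fact that struct ordering compares fields in declaration order, so sorting the structs sorts by distCol first; the alias ranked_urlB is just illustrative:
# UDF-free alternative (Spark 2.4+): array_sort orders the structs by distCol first,
# slice keeps the first top_N elements (slice is 1-indexed)
df3 = df.select('urlA', F.slice(F.array_sort('urlB'), 1, top_N).alias('ranked_urlB'))
df3.show(truncate=False)
The UDF version remains useful when the ranking logic is more involved than a plain sort.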