Действительный udf до 2.4 (обратите внимание, что он должен что-то возвращать
from pyspark.sql.functions import udf
@udf("boolean")
def contains_all(x, y):
if x is not None and y is not None:
return set(y).issubset(set(x))
В версии 2.4 или более поздней не требуется udf:
from pyspark.sql.functions import array_intersect, size
def contains_all(x, y):
return size(array_intersect(x, y)) == size(y)
Использование:
from pyspark.sql.functions import col, sum as sum_, when
df1 = spark.createDataFrame(
[(1, [1, 2, 3, 5]), (2, [1, 2, 3, 6]), (3, [1, 2, 9, 8]), (4, [1, 2, 5, 6])],
("id", "transactions")
)
df2 = spark.createDataFrame(
[([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)],
("items", "cost")
)
(df1
.crossJoin(df2).groupBy("id", "transactions")
.agg(sum_(when(
contains_all("transactions", "items"), col("cost")
)).alias("score"))
.show())
Результат:
+---+------------+-----+
| id|transactions|score|
+---+------------+-----+
| 1|[1, 2, 3, 5]| 4.0|
| 4|[1, 2, 5, 6]| 6.0|
| 2|[1, 2, 3, 6]| 6.0|
| 3|[1, 2, 9, 8]| 4.0|
+---+------------+-----+
Если df2
мало, он может предпочесть использовать его как локальную переменную:
items = sc.broadcast([
(set(items), cost) for items, cost in df2.select("items", "cost").collect()
])
def score(y):
@udf("double")
def _(x):
if x is not None:
transactions = set(x)
return sum(
cost for items, cost in y.value
if items.issubset(transactions))
return _
df1.withColumn("score", score(items)("transactions")).show()
+---+------------+-----+
| id|transactions|score|
+---+------------+-----+
| 1|[1, 2, 3, 5]| 4.0|
| 2|[1, 2, 3, 6]| 6.0|
| 3|[1, 2, 9, 8]| 4.0|
| 4|[1, 2, 5, 6]| 6.0|
+---+------------+-----+
Наконец, возможновзорваться и присоединиться к
from pyspark.sql.functions import explode
costs = (df1
# Explode transactiosn
.select("id", explode("transactions").alias("item"))
.join(
df2
# Add id so we can later use it to identify source
.withColumn("_id", monotonically_increasing_id().alias("_id"))
# Explode items
.select(
"_id", explode("items").alias("item"),
# We'll need size of the original items later
size("items").alias("size"), "cost"),
["item"])
# Count matches in groups id, items
.groupBy("_id", "id", "size", "cost")
.count()
# Compute cost
.groupBy("id")
.agg(sum_(when(col("size") == col("count"), col("cost"))).alias("score")))
costs.show()
+---+-----+
| id|score|
+---+-----+
| 1| 4.0|
| 3| 4.0|
| 2| 6.0|
| 4| 6.0|
+---+-----+
, а затем соединить результат обратно с оригинальным df1
,
df1.join(costs, ["id"])
, но это гораздо менее простое решение и требует нескольких перемешиваний.по-прежнему предпочтительнее декартового произведения (crossJoin
), но это будет зависеть от фактических данных.