Нельзя ссылаться на DataFrame внутри udf
. Как вы уже упоминали, эту проблему лучше всего решить с помощью join
.
IIUC, вы ищете что-то вроде:
from pyspark.sql import Window
import pyspark.sql.functions as F
df1.alias("L").join(df2.alias("R"), (df1.n == df2.x1) | (df1.n == df2.x2), how="left")\
.select("L.*", F.sum("w").over(Window.partitionBy("n")).alias("gamma"))\
.distinct()\
.show()
#+---+---+----------+----------+
#| n|val| distances| gamma|
#+---+---+----------+----------+
#| 1| 1|0.27308652|0.75747334|
#| 3| 1|0.21314497| null|
#| 2| 1|0.24969208|0.03103427|
#+---+---+----------+----------+
Или, если вам удобнее использовать синтаксис pyspark-sql
, вы можете зарегистрировать временные таблицы и выполнить:
df1.registerTempTable("df1")
df2.registerTempTable("df2")
sqlCtx.sql(
"SELECT DISTINCT L.*, SUM(R.w) OVER (PARTITION BY L.n) AS gamma "
"FROM df1 L LEFT JOIN df2 R ON L.n = R.x1 OR L.n = R.x2"
).show()
#+---+---+----------+----------+
#| n|val| distances| gamma|
#+---+---+----------+----------+
#| 1| 1|0.27308652|0.75747334|
#| 3| 1|0.21314497| null|
#| 2| 1|0.24969208|0.03103427|
#+---+---+----------+----------+
Объяснение
В обоих случаях мы выполняем левое соединение из df1
в df2
. Это сохранит все строки в df1
независимо от совпадения.
Предложение join - это условие, которое вы указали в своем вопросе. Таким образом, все строки в df2
, где x1
или x2
равно n
, будут объединены.
Затем выберите все строки из левой таблицы, плюс мы сгруппируем по (split by) n
и суммируем значения w
. Это получит сумму по всем строкам, которые соответствуют условию соединения, для каждого значения n
.
Наконец, мы возвращаем только отдельные строки для устранения дубликатов.