Если я правильно понимаю проблему, вы можете разделить фрейм данных на два фрейма данных на основе столбца class
, а затем объединить их на основе указанного предложения объединения (используя внешнее соединение):
from pyspark.sql.functions import col, collect_list, struct
A_df = df.where(col('class') == 'A').withColumnRenamed('name', 'A.name')
B_df = df.where(col('class') == 'B').withColumnRenamed('name', 'B.name')
join_clause = A_df.y - B_df.y <= 10
result = A_df.join(B_df, join_clause, 'outer')
И с результирующим фреймом данных преобразуйте два столбца в один столбец списка:
result = result.withColumn(collect_list(struct(col('A.name'), col('B.name')))
Обновление
Вот реализация чего-то с использованием mapPartitions
, нет присоединяется или преобразуется в DataFrame:
import math
from pyspark.sql import SparkSession
def process_data(rows):
r = 1000
joins = []
for row1 in rows:
for row2 in rows:
if row1['class'] != row2['class']:
if row1['x'] < row2['x'] + r:
if math.sqrt((row1['x'] - row2['x']) ** 2 + (row1['y'] - row2['y']) ** 2) <= r:
joins.append((row1['name'], row2['name']))
return joins
spark = SparkSession \
.builder \
.appName("Python Spark SQL basic example") \
.getOrCreate()
people = [('john', 35, 54, 'A'),
('george', 94, 84, 'B'),
('nicolas', 7, 9, 'B'),
('tom', 86, 93, 'A'),
('jason', 62, 73, 'B'),
('bill', 15, 58, 'A'),
('william', 9, 3, 'A'),
('brad', 73, 37, 'B'),
('cosmo', 52, 67, 'B'),
('jerry', 73, 30, 'A')]
fields = ('name', 'x', 'y', 'class')
data = [dict(zip(fields, person)) for person in people]
rdd = spark.sparkContext.parallelize(data)
result = rdd.mapPartitions(process_data).collect()
print(result)
Вывод:
[('tom', 'jason'), ('cosmo', 'jerry')]
Обновление 2
Добавлена начальная сортировка перейдите в поле 'y', переделите, чтобы убедиться, что все данные находятся в одном разделе (чтобы можно было сравнить все записи), и изменили вложенный l oop:
import math
from pyspark.sql import SparkSession
def process_data(rows):
r = 1000
joins = []
rows = list(rows)
for i, row1 in enumerate(rows):
for row2 in rows[i:]:
if row1['class'] != row2['class']:
if row1['x'] < row2['x'] + r:
if math.sqrt((row1['x'] - row2['x']) ** 2 + (row1['y'] - row2['y']) ** 2) < r:
joins.append((row1['name'], row2['name']))
return joins
spark = SparkSession \
.builder \
.appName("Python Spark SQL basic example") \
.getOrCreate()
people = [('john', 35, 54, 'A'),
('george', 94, 84, 'B'),
('nicolas', 7, 9, 'B'),
('tom', 86, 93, 'A'),
('jason', 62, 73, 'B'),
('bill', 15, 58, 'A'),
('william', 9, 3, 'A'),
('brad', 73, 37, 'B'),
('cosmo', 52, 67, 'B'),
('jerry', 73, 30, 'A')]
fields = ('name', 'x', 'y', 'class')
data = [dict(zip(fields, person)) for person in people]
rdd = spark.sparkContext.parallelize(data)
result = rdd.sortBy(lambda x: x['y'], ascending=False).repartition(1).mapPartitions(process_data).collect()
print(result)
Вывод:
[('william', 'nicolas'), ('william', 'brad'), ('william', 'cosmo'), ('william', 'jason'), ('william', 'george'), ('nicolas', 'jerry'), ('nicolas', 'john'), ('nicolas', 'bill'), ('nicolas', 'tom'), ('jerry', 'brad'), ('jerry', 'cosmo'), ('jerry', 'jason'), ('jerry', 'george'), ('brad', 'john'), ('brad', 'bill'), ('brad', 'tom'), ('john', 'cosmo'), ('john', 'jason'), ('john', 'george'), ('bill', 'cosmo'), ('bill', 'jason'), ('bill', 'george'), ('cosmo', 'tom'), ('jason', 'tom'), ('george', 'tom')]