У меня есть следующий фрейм данных:
from pyspark.sql import SparkSession
sqlContext = SparkSession.builder.appName("test").enableHiveSupport().getOrCreate()
data = [(1,2,0.1,0.3),(1,2,0.1,0.3),(1,3,0.1,0.3),(1,3,0.1,0.3),
(11, 12, 0.1, 0.3),(11,12,0.1,0.3),(11,13,0.1,0.3),(11,13,0.1,0.3)]
trajectory_df = sqlContext.createDataFrame(data, schema=['grid_id','rider_id','lng','lat'])
trajectory_df.show()
+-------+--------+---+---+
|grid_id|rider_id|lng|lat|
+-------+--------+---+---+
| 1| 2|0.1|0.3|
| 1| 2|0.1|0.3|
| 1| 3|0.1|0.3|
| 1| 3|0.1|0.3|
| 11| 12|0.1|0.3|
| 11| 12|0.1|0.3|
| 11| 13|0.1|0.3|
| 11| 13|0.1|0.3|
+-------+--------+---+---+
Я хочу объединить данные из той же сетки в DICT. Где rider_id
- ключ dict, а широта и долгота - значение dict.
Результаты, которые я ожидаю, следующие:
[(1, {3:[[0.1, 0.3], [0.1, 0.3]],2:[[0.1, 0.3], [0.1, 0.3]]}),
(11,{13:[[0.1, 0.3], [0.1, 0.3]],12:[[0.1, 0.3], [0.1, 0.3]]})]
Я могу использовать groupByKey()
для группировки grid_id
.
def trans_point(row):
return ((row.grid_id, row.rider_id), [row.lng, row.lat])
trajectory_df = trajectory_df.rdd.map(trans_point).groupByKey().mapValues(list)
print(trajectory_df.take(10))
[((1, 3), [[0.1, 0.3], [0.1, 0.3]]), ((11, 13), [[0.1, 0.3], [0.1, 0.3]]), ((1, 2), [[0.1, 0.3], [0.1, 0.3]]), ((11, 12), [[0.1, 0.3], [0.1, 0.3]])]
Но я не могу получить результат, когда объединяю несколько dict:
trajectory_df = trajectory_df.map(lambda x:(x[0][0],{x[0][1]:x[1]})).reduceByKey(lambda x,y:x.update(y))
print(trajectory_df.take(10))
[(1, None), (11, None)]
Я надеюсь, что это сделано по типу RDD по некоторым причинам. Как мне этого добиться? Заранее спасибо.