Так что мне удалось сделать это самостоятельно:
from pyspark.sql.types import StructType, StructField, DoubleType, StringType, IntegerType
from sklearn.cluster import DBSCAN
data = [(1, 11.6133, 48.1075),
(1, 11.6142, 48.1066),
(1, 11.6108, 48.1061),
(1, 11.6207, 48.1192),
(1, 11.6221, 48.1223),
(1, 11.5969, 48.1276),
(2, 11.5995, 48.1258),
(2, 11.6127, 48.1066),
(2, 11.6430, 48.1275),
(2, 11.6368, 48.1278),
(2, 11.5930, 48.1156)]
df = spark.createDataFrame(data, ["id", "X", "Y"])
output_schema = StructType(
[
StructField('id', StringType()),
StructField('X', DoubleType()),
StructField('Y', DoubleType()),
StructField('cluster', IntegerType())
]
)
@pandas_udf(output_schema, PandasUDFType.GROUPED_MAP)
def dbscan_pandas_udf(data):
data["cluster"] = DBSCAN(eps=5, min_samples=3).fit_predict(data[["X", "Y"]])
result = pd.DataFrame(data, columns=["id", "X", "Y", "cluster"])
return result
data.groupby("id").apply(dbscan_udf).show()
Надеюсь, это может кому-то помочь в будущем.