Сравнение всех строк в столбце со всеми другими строками в одном столбце (специальный запрос) - PullRequest
0 голосов
/ 16 мая 2019

В этом запросе мне дан кадр данных со столбцом 5d евклидовых точек (хранится в виде массива двойных чисел). Мне нужно найти все средние доступные расстояния. То есть для каждой точки a я вычисляю расстояние до другой точки b в кадре данных и нахожу среднее значение этих расстояний. Обратите внимание, что я не хочу никаких математических подходов или упрощений к этому вопросу. Фрейм данных имеет два столбца: unique_id и vector.

Я мог бы сделать запрос, но только по 1 пункту следующим образом. Расстояние UDF вычисляет расстояние между сохраненным массивом (то есть упакованным массивом) и данным массивом. Однако ясно, что этот подход работает только для одной точки. Также я попытался передать набор данных в статическую функцию. но каждый раз, когда я это делаю, я получаю «Invalid Tree: null», то есть объект становится нулевым, как только он входит в функцию ... Наконец, я подумал о создании UDAF, но я понял, что это не правильная агрегатная функция. Любая помощь в этом будет оценена!

(Примечание: этот код написан на Java, но он не должен сильно отличаться от других языков)

        long equal = 2;
        WrappedArray<Double> num = (WrappedArray<Double> spo.select("vectors")
       .filter(col("unique_id").equalTo(equal)).first().get(0);
        List<Double> frameList =  scala.collection.JavaConverters.seqAsJavaList(num);

        double[] array_answer = frameList.stream().mapToDouble(Double::doubleValue).toArray();
        UserDefinedFunction compare = udf(
                (WrappedArray<Double> array)  -> cosine_distance(array, array_answer),  DataTypes.DoubleType
        );
        double answer = (double) spo.select("vectors").filter(col("unique_id").notEqual(equal))
            .withColumn("calc", compare.apply(col("vectors")))
            .select(avg("calc")).first().get(0);
        System.out.println(answer);

1 Ответ

0 голосов
/ 16 мая 2019

Это можно сделать с помощью crossJoin.Вот 1d (псевдо) код в Scala:

val df = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("unique_id", "vector")

df.select($"unique_id" as "id0", $"vector" as "vector0")
  .crossJoin(df.select($"unique_id" as "id1", $"vector" as "vector1"))
  .filter($"id0" =!= $"id1")
  .groupBy($"id0" as "unique_id")
  .agg(avg(
    abs($"vector0" - $"vector1") /*  use actual distance here */ ) as "mean_distance")
  .show()
+---------+-------------+
|unique_id|mean_distance|
+---------+-------------+
|        c|          1.5|
|        b|          1.0|
|        a|          1.5|
+---------+-------------+
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...