Повышение производительности отдельных фрагментов кода в Java и Python - PullRequest
1 голос
/ 18 апреля 2011

У меня есть вопрос относительно производительности определенного фрагмента кода на Java и Python.

Алгоритм:
Я генерирую случайных N-мерных точек , а затем для всех точек, находящихся на определенном расстоянии друг от друга, выполняю некоторую обработку. Сама обработка здесь не важна, так как она не влияет на общее время выполнения. Генерация точек также занимает доли секунды в обоих случаях, поэтому меня интересует только та часть, которая выполняет сравнение.

Время выполнения:
Для фиксированного ввода в 3000 точек и 2 измерений Java делает это за от 2 до 4 секунд , тогда как Python занимает где-то от 15 до 200 секунд .

Я немного скептически отношусь ко времени выполнения Python. Что-то мне не хватает в этом коде Python? Есть ли какие-либо алгоритмические предложения по улучшению (например, предварительное выделение / повторное использование памяти, способ снижения сложности Big-Oh и т. Д.)?


Java

double random_points[][] = new double[number_of_points][dimensions];
for(i = 0; i < number_of_points; i++)
  for(d = 0; d < dimensions; d++)
    random_points[i][d] = Math.random();

double p1[], p2[];
for(i = 0; i < number_of_points; i++)
{
  p1 = random_points[i];
  for(j = i + 1; j < number_of_points; j++)
  {
    p2 = random_points[j];

    double sum_of_squares = 0;
    for(d = 0; d < DIM_; d++)
      sum_of_squares += (p2[d] - p1[d]) * (p2[d] - p1[d]);

    double distance = Math.sqrt(ss);
    if(distance > SOME_THRESHOLD) continue;

    //...else do something with p1 and p2

  }
}

Python 3,2

random_points = [[random.random() for _d in range(0,dimensions)] for _n in range(0,number_of_points)]

for i, p1 in enumerate(random_points):
  for j, p2 in enumerate(random_points[i+1:]):
    distance = math.sqrt(sum([(p1[d]-p2[d])**2 for d in range(0,dimensions)]))
    if distance > SOME_THRESHOLD: continue

    #...else do something with p1 and p2

Ответы [ 3 ]

5 голосов
/ 18 апреля 2011

Возможно, вы захотите использовать numpy .

Я только что попробовал следующее:

import numpy
from scipy.spatial.distance import pdist
D=2
N=3000
p=numpy.random.uniform(size=(N,D))
dist=pdist(p, 'euclidean')

Последняя строка вычисляет матрицу расстояний (это эквивалентно вычислению distance в вашем коде для каждой пары точек). На моем компьютере это занимает около 0,07 с.

Основным недостатком этого метода является то, что ему требуется O(n^2) память для матрицы расстояний. Если это проблема, лучшим вариантом может быть следующее:

for i in xrange(1, N):
  v = p[:N-i] - p[i:]
  dist = numpy.sqrt(numpy.sum(numpy.square(v), axis=1))
  for j in numpy.nonzero(dist > 1.4)[0]:
    print j, i+j

Для N = 3000 это занимает ~ 0,33 с на моем компьютере.

2 голосов
/ 18 апреля 2011

Если взять 30 тыс. Точек и 5 измерений, что в 100 раз больше работы.

int number_of_points = 30000;
int dimensions = 5;
double SOME_THRESHOLD = 0.1;

long start = System.nanoTime();
double random_points[][] = new double[number_of_points][dimensions];
for (int i = 0; i < number_of_points; i++)
    for (int d = 0; d < dimensions; d++)
        random_points[i][d] = Math.random();

double p1[], p2[];
Comparator<double[]> compareX = new Comparator<double[]>() {
    @Override
    public int compare(double[] o1, double[] o2) {
        return Double.compare(o1[0], o2[0]);
    }
};
Arrays.sort(random_points, compareX);

double[] key = new double[dimensions];
int count = 0;
for (int i = 0; i < number_of_points; i++) {
    p1 = random_points[i];
    key[0] = p1[0] + SOME_THRESHOLD;
    int index = Arrays.binarySearch(random_points, key, compareX);
    if (index < 0) index = ~index;
    NEXT: for (int j = i + 1; j < index; j++) {
        p2 = random_points[j];

        double sum_of_squares = 0;
        for (int d = 0; d < dimensions; d++) {
            sum_of_squares += (p2[d] - p1[d]) * (p2[d] - p1[d]);
            if (sum_of_squares > SOME_THRESHOLD * SOME_THRESHOLD) 
                continue NEXT;
        }

        //...else do something with p1 and p2
        count++;
    }
}
long time = System.nanoTime() - start;
System.out.println("Took " + time / 1000 / 1000 + " ms. count= " + count);

Печать

Took 1549 ms. count= 20197
1 голос
/ 18 апреля 2011

Действительно ли скорость имеет значение?Вот несколько очевидных ускорений:

  1. Не вычисляйте квадратный корень.Просто возведите в квадрат свой порог и сравните его с квадратным порогом.

  2. Сортируйте ваши точки по одному измерению (внешнему в вашей петле).Если две точки i и j находятся дальше, чем ваше пороговое значение только в этом измерении, то дальнейшее увеличение j приведет к созданию точек только дальше, чем это пороговое значение, и вы можете continue внешний цикл.

Могут быть и другие алгоритмические ускорения, даже выше все равно O (n d ).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...