Сравнение двух значений в форме (a + sqrt (b)) как можно быстрее? - PullRequest
45 голосов
/ 08 мая 2019

Как часть программы, которую я пишу, мне нужно сравнить два значения в форме a + sqrt(b), где a и b - целые числа без знака.Поскольку это часть узкого цикла, я бы хотел, чтобы это сравнение выполнялось как можно быстрее.(Если это имеет значение, я запускаю код на компьютерах с архитектурой x86-64, а целые числа без знака не превышают 10 ^ 6. Кроме того, я точно знаю, что a1<a2.)

Какавтономная функция, это то, что я пытаюсь оптимизировать.Мои числа достаточно малы, чтобы double (или даже float) могли точно их представить, но ошибка округления в результатах sqrt не должна изменить результат.

// known pre-condition: a1 < a2  in case that helps
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);  // computed mathematically exactly
}

Контрольный пример: is_smaller(900000, 1000000, 900001, 998002) должен возвращать true, но, как показано в комментариях @wim, вычисление его с sqrtf() вернет false.Так что (int)sqrt() будет усечено обратно до целого числа.

a1+sqrt(b1) = 90100 и a2+sqrt(b2) = 901000.00050050037512481206.Ближайшее значение с плавающей точкой - 90100.


Поскольку функция sqrt() обычно довольно дорогая, даже на современном x86-64, когда она полностью встроена как инструкция sqrtsd, я стараюсь избегатьвызов sqrt() настолько далеко, насколько это возможно.

Удаление sqrt путем возведения в квадрат потенциально также позволяет избежать любой опасности округления ошибок, делая все вычисления точными.

Если вместо этого функция была чем-то вроде этого ...

bool is_smaller(unsigned a1, unsigned b1, unsigned x) {
    return a1+sqrt(b1) < x;
}

... тогда я мог бы просто сделать return x-a1>=0 && static_cast<uint64_t>(x-a1)*(x-a1)>b1;

Но теперь, когда есть два sqrt(...) термина, я не могу делать те же алгебраические манипуляции.

Я мог бы возвести в квадрат значения дважды , используя эту формулу:

      a1 + sqrt(b1) = a2 + sqrt(b2)
<==>  a1 - a2 = sqrt(b2) - sqrt(b1)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1) * sqrt(b2)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1 * b2)
<==>  (a1 - a2) * (a1 - a2) - (b1 + b2) = - 2 * sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 2 = sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) * ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 4 = b1 * b2

Беззнаковое деление на 4 дешево, потому что это всего лишь битное смещение, но так как я возведу в квадрат числа дваждынужно использовать 128-битные целые числа, и мне нужно будет ввести несколько проверок >=0 (потому что я сравниваю неравенство вместо равенства).

Такое ощущение, что может быть способ сделать это быстрее,применяя лучше ALGЭбра к этой проблеме.Есть ли способ сделать это быстрее?

Ответы [ 5 ]

18 голосов
/ 08 мая 2019

Вот версия без sqrt, хотя я не уверен, что она быстрее, чем версия, которая имеет только один sqrt (это может зависеть от распределения значений).

Вотматематика (как удалить оба sqrts):

ad = a2-a1
bd = b2-b1

a1+sqrt(b1) < a2+sqrt(b2)              // subtract a1
   sqrt(b1) < ad+sqrt(b2)              // square it
        b1  < ad^2+2*ad*sqrt(b2)+b2    // arrange
   ad^2+bd  > -2*ad*sqrt(b2)

Здесь правая сторона всегда отрицательна.Если левая сторона положительна, то мы должны вернуть true.

Если левая сторона отрицательна, тогда мы можем выровнять неравенство:

ad^4+bd^2+2*bd*ad^2 < 4*ad^2*b2

Ключевым моментом, на который следует обратить внимание, являетсячто если a2>=a1+1000, то is_smaller всегда возвращает true (поскольку максимальное значение sqrt(b1) равно 1000).Если a2<=a1+1000, то ad - это небольшое число, поэтому ad^4 всегда будет соответствовать 64-разрядному (нет необходимости в 128-разрядной арифметике).Вот код:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    if (ad>1000) {
        return true;
    }

    int bd = b2 - b1;
    if (ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;

    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}

РЕДАКТИРОВАТЬ: Как заметил Питер Кордес, первый if не является необходимым, как второй, если обрабатывает его, поэтому код становится меньше и быстрее:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    int bd = b2 - b1;
    if ((long long int)ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;
    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}
4 голосов
/ 08 мая 2019

Я устал и, вероятно, ошибся;но я уверен, что если я это сделаю, кто-то укажет на это ..

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a1-a2;   // May be negative

    if(a_diff < 0) {
        if(b1 < b2) {
            return true;
        }
        temp = a_diff+sqrt(b1);
        if(temp < 0) {
            return true;
        }
        return temp*temp < b2;
    } else {
        if(b1 >= b2) {
            return false;
        }
    }
//  return a_diff+sqrt(b1) < sqrt(b2);

    temp = a_diff+sqrt(b1);
    return temp*temp < b2;
}

Если вы знаете a1 < a2, тогда это может стать:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a2-a1;    // Will be positive

    if(b1 > b2) {
        return false;
    }
    if(b1 >= a_diff*a_diff) {
        return false;
    }
    temp = a_diff+sqrt(b2);
    return b1 < temp*temp;
}
2 голосов
/ 08 мая 2019

Я не уверен, что алгебраические манипуляции в сочетании с целочисленной арифметикой обязательно приводят к быстрейшему решению.В этом случае вам понадобится много скалярных умножений (что не очень быстро), и / или предсказание ветвления может потерпеть неудачу, что может снизить производительность.Очевидно, что вам придется тестировать, чтобы увидеть, какое решение является самым быстрым в вашем конкретном случае.

Один из способов сделать sqrt немного быстрее - добавить параметр -fno-math-errno в gcc или clang.В этом случае компилятору не нужно проверять наличие отрицательных входных данных.При использовании icc это значение по умолчанию.

Более эффективное улучшение возможно при использовании векторизованной инструкции sqrt sqrtpd вместо скалярной инструкции sqrt sqrtsd.Питер Кордес показал , что clang способен автоматически векторизовать этот код, так что он генерирует этот sqrtpd.

Однако успешность автоматической векторизации в значительной степени зависит от правильных настроек компилятораи используемый компилятор (clang, gcc, icc и т. д.).При -march=nehalem или старше clang не векторизируется.

Более надежные результаты векторизации возможны с помощью следующего встроенного кода, см. Ниже.Для переносимости мы предполагаем только поддержку SSE2, которая является базовой для x86-64.

/* gcc -m64 -O3 -fno-math-errno smaller.c                      */
/* Adding e.g. -march=nehalem or -march=skylake might further  */
/* improve the generated code                                  */
/* Note that SSE2 in guaranteed to exist with x86-64           */
#include<immintrin.h>
#include<math.h>
#include<stdio.h>
#include<stdint.h>

int is_smaller_v5(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    uint64_t a64    =  (((uint64_t)a2)<<32) | ((uint64_t)a1); /* Avoid too much port 5 pressure by combining 2 32 bit integers in one 64 bit integer */
    uint64_t b64    =  (((uint64_t)b2)<<32) | ((uint64_t)b1); 
    __m128i ax      = _mm_cvtsi64_si128(a64);         /* Move integer from gpr to xmm register                  */
    __m128i bx      = _mm_cvtsi64_si128(b64);         
    __m128d a       = _mm_cvtepi32_pd(ax);            /* Convert 2 integers to double                           */
    __m128d b       = _mm_cvtepi32_pd(bx);            /* We don't need _mm_cvtepu32_pd since a,b < 1e6          */
    __m128d sqrt_b  = _mm_sqrt_pd(b);                 /* Vectorized sqrt: compute 2 sqrt-s with 1 instruction   */
    __m128d sum     = _mm_add_pd(a, sqrt_b);
    __m128d sum_lo  = sum;                            /* a1 + sqrt(b1) in the lower 64 bits                     */
    __m128d sum_hi  =  _mm_unpackhi_pd(sum, sum);     /* a2 + sqrt(b2) in the lower 64 bits                     */
    return _mm_comilt_sd(sum_lo, sum_hi);
}


int is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);
}


int main(){
    unsigned a1; unsigned b1; unsigned a2; unsigned b2;
    a1 = 11; b1 = 10; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 11; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 11; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 10; b2 = 11;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));

    return 0;
}

См. эту ссылку Godbolt для созданной сборки.

В простом тесте пропускной способности на Intel Skylake с параметрами компилятора gcc -m64 -O3 -fno-math-errno -march=nehalem я обнаружил пропускную способность is_smaller_v5() который был в 2,6 раза лучше, чем исходный is_smaller(): 6,8 такта процессора против 18 циклов процессора, с учетом накладных расходов цикла.Однако в (слишком?) Простом тесте задержки, где входные данные a1, a2, b1, b2 зависели от результата предыдущего is_smaller(_v5), я не увидел никаких улучшений.(39,7 цикла против 39 циклов).

2 голосов
/ 08 мая 2019

Существует также метод Ньютона для вычисления целых sqrts как , описанный здесь Другой подход будет состоять в том, чтобы не вычислять квадратный корень, а искать слово (sqrt (n)) с помощью бинарного поиска ... есть "только «1000 полных квадратных чисел меньше 10 ^ 6.Это, вероятно, имеет плохую производительность, но будет интересным подходом.Я не измерял ни одного из них, но вот примеры:

#include <iostream>
#include <array>
#include <algorithm>        // std::lower_bound
#include <cassert>          


bool is_smaller_sqrt(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt(b1) < a2 + sqrt(b2);
}

static std::array<int, 1001> squares;

template <typename C>
void squares_init(C& c)
{
    for (int i = 0; i < c.size(); ++i)
        c[i] = i*i;
}

inline bool greater(const int& l, const int& r)
{
    return r < l;
}

inline bool is_smaller_bsearch(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    // return a1 + sqrt(b1) < a2 + sqrt(b2)

    // find floor(sqrt(b1)) - binary search withing 1000 elems
    auto it_b1 = std::lower_bound(crbegin(squares), crend(squares), b1, greater).base();

    // find floor(sqrt(b2)) - binary search withing 1000 elems
    auto it_b2 = std::lower_bound(crbegin(squares), crend(squares), b2, greater).base();

    return (a2 - a1) > (it_b1 - it_b2);
}

unsigned int sqrt32(unsigned long n)
{
    unsigned int c = 0x8000;
    unsigned int g = 0x8000;

    for (;;) {
        if (g*g > n) {
            g ^= c;
        }

        c >>= 1;

        if (c == 0) {
            return g;
        }

        g |= c;
    }
}

bool is_smaller_sqrt32(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt32(b1) < a2 + sqrt32(b2);
}

int main()
{
    squares_init(squares);

    // now can use is_smaller
    assert(is_smaller_sqrt(1, 4, 3, 1) == is_smaller_sqrt32(1, 4, 3, 1));
    assert(is_smaller_sqrt(1, 2, 3, 3) == is_smaller_sqrt32(1, 2, 3, 3));
    assert(is_smaller_sqrt(1000, 4, 1001, 1) == is_smaller_sqrt32(1000, 4, 1001, 1));
    assert(is_smaller_sqrt(1, 300, 3, 200) == is_smaller_sqrt32(1, 300, 3, 200));
}
1 голос
/ 08 мая 2019

Возможно, не лучше, чем другие ответы, но использует другую идею (и массу предварительного анализа).

// Compute approximate integer square root of input in the range [0,10^6].
// Uses a piecewise linear approximation to sqrt() with bounded error in each piece:
//   0 <= x <= 784 : x/28
//   784 < x <= 7056 : 21 + x/112
//   7056 < x <= 28224 : 56 + x/252
//   28224 < x <= 78400 : 105 + x/448
//   78400 < x <= 176400 : 168 + x/700
//   176400 < x <= 345744 : 245 + x/1008
//   345744 < x <= 614656 : 336 + x/1372
//   614656 < x <= 1000000 : (784000+x)/1784
// It is the case that sqrt(x) - 7.9992711366390365897... <= pseudosqrt(x) <= sqrt(x).
unsigned pseudosqrt(unsigned x) {
    return 
        x <= 78400 ? 
            x <= 7056 ?
                x <= 764 ? x/28 : 21 + x/112
              : x <= 28224 ? 56 + x/252 : 105 + x/448
          : x <= 345744 ?
                x <= 176400 ? 168 + x/700 : 245 + x/1008
              : x <= 614656 ? 336 + x/1372 : (x+784000)/1784 ;
}

// known pre-conditions: a1 < a2, 
//                  0 <= b1 <= 1000000
//                  0 <= b2 <= 1000000
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
// Try three refinements:
// 1: a1 + sqrt(b1) <= a1 + 1000, 
//    so is a1 + 1000 < a2 ?  
//    Convert to a2 - a1 > 1000 .
// 2: a1 + sqrt(b1) <= a1 + pseudosqrt(b1) + 8 and
//    a2 + pseudosqrt(b2) <= a2 + sqrt(b2), 
//    so is  a1 + pseudosqrt(b1) + 8 < a2 + pseudosqrt(b2) ?
//    Convert to a2 - a1 > pseudosqrt(b1) - pseudosqrt(b2) + 8 .
// 3: Actually do the work.
//    Convert to a2 - a1 > sqrt(b1) - sqrt(b2)
// Use short circuit evaluation to stop when resolved.
    unsigned ad = a2 - a1;
    return (ad > 1000)
           || (ad > pseudosqrt(b1) - pseudosqrt(b2) + 8)
           || ((int) ad > (int)(sqrt(b1) - sqrt(b2)));
}

(у меня нет под рукой компилятора, так что, возможно, он содержит опечатку или две.)

...