Квадрат кватерниона с использованием AVX - PullRequest
0 голосов
/ 24 апреля 2020

Кто-нибудь знает, как векторизовать эту функцию, используя AVX

void cuadradoYSumaNormal(quaternion* a, quaternion* b, quaternion* c) {
          c->w = a->w*a->w - a->x*a->x - a->y*a->y - a->z*a->z + b->w;
          c->x = 2.*a->w*a->x + b->x;
          c->y = 2.*a->w*a->y + b->y;
          c->z = 2.*a->w*a->z + b->z;
    }

Я могу предположить, что длина единицы для a, b и c

quaternion является следующей структурой :

struct quaternion{
  double w;
  double x;
  double y;
  double z;
};

Функция должна вывести квадрат кватерниона *a (используя правила умножения кватернинов), затем добавить кватернион *b и сохранить результат в *c.

1 Ответ

3 голосов
/ 25 апреля 2020

Это решение работает, если a имеет единичную длину, то есть aw^2+ax^2+ay^2+az^2 == 1

. В этом случае вычисление c->w эквивалентно вычислению 2*a->w*a->w - 1.0 + b->w, что значительно упрощает векторизацию. , Умножение на 2 может быть достигнуто путем добавления a (или a->w) к себе. Чтобы уменьшить задержку, необходимо добавить -1.0 к b->w. Возможная реализация:

inline __m256d unit(double value = 1.0)
{
    return _mm256_set_pd(0,0,0,value);
}

void cuadradoYSumaNormal_avx(quaternion* a, quaternion* b, quaternion* c) {

    __m256d aw = _mm256_broadcast_sd(&a->w);
    __m256d a_ = _mm256_loadu_pd(&a->w);
    __m256d b_ = _mm256_loadu_pd(&b->w);

    __m256d a_squared_plus_one = _mm256_mul_pd(aw, _mm256_add_pd(a_,a_));
    __m256d c_ = _mm256_add_pd(a_squared_plus_one, _mm256_add_pd(b_, unit(-1.0)));

    _mm256_storeu_pd(&c->w, c_);
}

Если помимо AVX у вас есть доступная FMA, вы можете присоединить некоторые сложения и умножения к

(aw * a + [-0.5,0,0,0]) * 2.0 + b

В результате всего две FMA (и одна трансляция и несколько нагрузок) , Возможная реализация:

void cuadradoYSumaNormal_fma(quaternion* a, quaternion* b, quaternion* c) {

    __m256d aw = _mm256_broadcast_sd(&a->w);
    __m256d a_ = _mm256_loadu_pd(&a->w);
    __m256d b_ = _mm256_loadu_pd(&b->w);

    __m256d a_squared_half = _mm256_fmadd_pd(aw, a_, unit(-0.5));
    __m256d c_ = _mm256_fmadd_pd(a_squared_half, _mm256_set1_pd(2.0), b_);

    _mm256_storeu_pd(&c->w, c_);
}
...