Почему _mm512_store_pd очень медленный в этом коде умножения матриц? - PullRequest
0 голосов
/ 28 января 2020

Я играю с avx512 и умножением матриц, но я, должно быть, делаю что-то не так, потому что у меня ужасные результаты, когда я пытаюсь сохранить свои результаты, используя _mm512_store_pd.

Вот соответствующие фрагменты кода, сначала структура данных, которую я использую, и как я ее инициализирую:


typedef struct {
        double* values;
        int nb_l;
        int nb_c;
} matrix;

matrix* alloc_matrix(int nb_l, int nb_c){
        matrix* tmp_matrix = (matrix*)malloc(sizeof(matrix));
        tmp_matrix->values = (double*)aligned_alloc(64, sizeof(double) * nb_l * nb_c);
        tmp_matrix->nb_l = nb_l;
        tmp_matrix->nb_c = nb_c;
        return tmp_matrix;
}

И вот как я пытаюсь умножить две инициализированные матрицы в другом месте в моем коде:

matrix* mult_matrix(matrix* A, matrix* B){
        /* avx512 */
        matrix* res_matrix = zero_matrix(A->nb_l, B->nb_c);
        double* res_ptr; // start index of the current line in res_matrix
        double* B_ptr; // start index of the current line in B

        __m512d A_broadcast, B_l_8, res_ptr_8;
        for (unsigned int idx_A = 0; idx_A < A->nb_l * A-> nb_c; idx_A++){
                // broadcast current value of A  eight times
                A_broadcast = _mm512_set1_pd(A->values[idx_A]);
                res_ptr = res_matrix->values + (idx_A / A->nb_c) * B->nb_c;
                B_ptr = B->values + (idx_A % A->nb_c) * B->nb_c;
                for (unsigned int offset_B = 0; offset_B < B->nb_c; offset_B+=8){
                        B_l_8 = _mm512_load_pd(&B_ptr[offset_B]);
                        res_ptr_8 = _mm512_load_pd(&res_ptr[offset_B]);
                        _mm512_store_pd(
                                &res_ptr[offset_B] , 
                                _mm512_fmadd_pd(A_broadcast, B_l_8, res_ptr_8)
                                );
                }
        }
        return res_matrix;

Результаты в порядке, но _mm512_store_pd занимает ~ 90% времени выполнения, на самом деле этот код avx512 чуть быстрее, чем его версия не AVX.

I ' Я перепробовал все, что мог придумать, но не могу понять, почему у меня такие неутешительные показатели с этим кодом У вас есть идеи?

Спасибо.

РЕДАКТИРОВАТЬ 1

Вот код не avx

        matrix* res_matrix = zero_matrix(A->nb_l, B->nb_c);
        double* res_ptr; // start index of the current line in res_matrix
        double* B_ptr; // start index of the current line in B

        for (unsigned int idx_A = 0; idx_A < A->nb_l * A-> nb_c; idx_A++){
                res_ptr = res_matrix->values + (idx_A / A->nb_c) * B->nb_c; 
                B_ptr = B->values + (idx_A % A->nb_c) * B->nb_c; 
                for (unsigned int offset_B = 0; offset_B < B->nb_c; offset_B++){                    
                        res_ptr[offset_B] += A->values[idx_A] * B_ptr[offset_B];
                }
        }
        return res_matrix;



Все матрицы представляют собой 512x512 случайных матриц, каждое умножение повторяется 50 раз, а время выполнения усредняется.

Наконец, приведенный ниже фрагмент должен быть в порядке, чтобы протестировать версии моего кода на avx и non_avx. Я скомпилировал его с помощью g cc 8.3.0, используя следующие параметры: g cc -Ofast -mavx -mavx512f -m64 -mfpmath = sse -mfma -flto -funroll-loops matrix_minimal. c

#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <time.h>
#include <string.h>
#include <immintrin.h>

typedef struct {
        double* values;
        int nb_l;
        int nb_c;
} matrix;

matrix* alloc_matrix(int nb_l, int nb_c){
        matrix* tmp_matrix = (matrix*)malloc(sizeof(matrix));
        tmp_matrix->values = (double*)aligned_alloc(64, sizeof(double) * nb_l * nb_c);
        tmp_matrix->nb_l = nb_l;
        tmp_matrix->nb_c = nb_c;
        return tmp_matrix;
}

void free_matrix(matrix** to_free){
        free((*to_free)->values);
        free(*to_free);
}

matrix* zero_matrix(int nb_l, int nb_c){
        matrix* z_matrix;
        z_matrix = alloc_matrix(nb_l, nb_c);
        for (int idx=0; idx < nb_l * nb_c; idx++){
                z_matrix->values[idx] = 0.0;
        }
        return z_matrix;
}
matrix* rand_matrix(int nb_l, int nb_c, double max_abs_val){
        static struct timeval seed; //static variables are zeroed at initialization
        matrix* rnd_matrix;
        rnd_matrix = alloc_matrix(nb_l, nb_c);

        if (seed.tv_sec == 0){ //ts_sec will never be zero after gettimeofday, whereas tv_usec could
                gettimeofday(&seed, NULL);
                srand((unsigned) seed.tv_usec);
        }
        for (int idx=0; idx < nb_l * nb_c; idx++){
                rnd_matrix->values[idx] = max_abs_val * ((double)rand() / RAND_MAX * 2.0 - 1.0);
        }

        return rnd_matrix;
}

matrix* mult_matrix_avx(matrix* A, matrix* B){
        /* pas trop mal en avx512 */
        matrix* res_matrix = zero_matrix(A->nb_l, B->nb_c);
        double* res_ptr; // start index of the current line in res_matrix
        double* B_ptr; // start index of the current line in B

        __m512d A_broadcast, B_l_8, res_ptr_8;
        for (unsigned int idx_A = 0; idx_A < A->nb_l * A-> nb_c; idx_A++){
                A_broadcast = _mm512_set1_pd(A->values[idx_A]); // broadcast current value of A eight times
                res_ptr = res_matrix->values + (idx_A / A->nb_c) * B->nb_c;
                B_ptr = B->values + (idx_A % A->nb_c) * B->nb_c;
                for (unsigned int offset_B = 0; offset_B < B->nb_c; offset_B+=8){
                        B_l_8 = _mm512_load_pd(&B_ptr[offset_B]);
                        res_ptr_8 = _mm512_load_pd(&res_ptr[offset_B]);
                        _mm512_store_pd(&res_ptr[offset_B] , _mm512_fmadd_pd(A_broadcast, B_l_8, res_ptr_8));
                }
        }
        return res_matrix;
}

matrix* mult_matrix(matrix* A, matrix* B){
        /* non avx512 */
        matrix* res_matrix = zero_matrix(A->nb_l, B->nb_c);
        double* res_ptr; // start index of the current line in res_matrix
        double* B_ptr; // start index of the current line in B

        for (unsigned int idx_A = 0; idx_A < A->nb_l * A-> nb_c; idx_A++){
                res_ptr = res_matrix->values + (idx_A / A->nb_c) * B->nb_c;
                B_ptr = B->values + (idx_A % A->nb_c) * B->nb_c;
                for (unsigned int offset_B = 0; offset_B < B->nb_c; offset_B++){
                        res_ptr[offset_B] += A->values[idx_A] * B_ptr[offset_B];
                }
        }
        return res_matrix;
}
int main(int argc, char *argv[]){
        struct timeval before;
        struct timeval after;

        matrix* A = rand_matrix(512, 512, 5);
        matrix* B = rand_matrix(512, 512, 5);
        matrix *C;
        gettimeofday(&before, NULL);
        for (int j=0; j<50;j++){
                C = mult_matrix_avx(A, B);
                free_matrix(&C); // we will measure the same overhead here and in the non avx version
        }
        gettimeofday(&after, NULL);
        double delta = ((after.tv_sec - before.tv_sec) * 1000000 +
                (after.tv_usec - before.tv_usec))/50;
        printf("avx %lf ms\n", delta);
        gettimeofday(&before, NULL);
        for (int j=0; j<50;j++){
                C = mult_matrix(A, B);
                free_matrix(&C); 
        }
        gettimeofday(&after, NULL);
        delta = ((after.tv_sec - before.tv_sec) * 1000000 +
                (after.tv_usec - before.tv_usec))/50;
        printf("non avx %lf ms\n", delta);

        free_matrix(&A);
        free_matrix(&B);
        return 0;
}

1 Ответ

0 голосов
/ 29 января 2020

@ chtz указал на очевидный ответ, я слишком часто обновлял результат, вместо того, чтобы полагаться на регистры mm512.

Если я сравниваю эту версию не avx:

        matrix* res_matrix = alloc_matrix(A->nb_l, B->nb_c);
        double tmp;
        for (unsigned int res_l=0; res_l < res_matrix->nb_l; res_l++){
                for (unsigned int res_c=0; res_c < res_matrix->nb_c; res_c++){
                        tmp = 0.0;
                        for (int offset = 0; offset < A->nb_c; offset++){
                                tmp += A->values[res_l * A->nb_c + offset] *
                                        B->values[offset * A->nb_c + res_c];
                        }
                        res_matrix->values[res_l * res_matrix->nb_c + res_c] = tmp;
                }

        }
        return res_matrix;

С его точным аналогом AVX:

        matrix* res_matrix = alloc_matrix(A->nb_l, B->nb_c);
        __m512d res_ptr_8;

        for (unsigned int res_l=0; res_l < res_matrix->nb_l; res_l++){
                for (unsigned int res_c=0; res_c < res_matrix->nb_c; res_c+=8){
                        // compute values from res_matrix[res_l, res_c] to [res_l, res_c+7]
                        res_ptr_8 = _mm512_set1_pd(0.0);
                        for (unsigned int offset_A_c = 0; offset_A_c < A->nb_c; offset_A_c++){
                                // on the res_l th line of A pick values one at a time
                                // at coordinates A[res_l, offset_A_c].
                                // Broadcast this value eight times into a mm512 vector
                                // and perform a dot product with the 8 values found in
                                // B from coordinates [offset_A_c, res_c] to [offset_A_c, res_c + 7]
                                res_ptr_8 = _mm512_fmadd_pd(
                                        _mm512_set1_pd(A->values[res_l * A->nb_c + offset_A_c]),
                                        _mm512_load_pd(&B->values[offset_A_c * B->nb_c + res_c]),
                                        res_ptr_8);
                        }
                        _mm512_store_pd(&res_matrix->values[res_l*res_matrix->nb_c + res_c] , res_ptr_8);
                }
        }

Время выполнения кода AVX в ~ 7 раз быстрее, чем не-AVX, что я и ожидал.

Однако, если сравнить это Новый код AVX для кода, отличного от AVX, вставлен в вопрос, увеличение скорости составляет всего около 3, что для меня все еще достаточно.

Редактировать

Эти приросты скорости измеряются на небольших блоках, чтобы поместиться в кэш, а не на матрицу 512 * 512, как в исходном сообщении.

Редактировать 2 Благодаря Питеру Кордесу (см. Ниже, отредактируйте: похоже, я совершенно не понял, что он сказал), вот обновленный код AVX, который в два раза быстрее, чем приведенный выше для блоков 48x48. Удивительно, но это даже быстрее, чем numpy / openblas на той же матрице.

        #define NB_L_STRIDE 8
        __m512d res_ptr_8[NB_L_STRIDE], B_ptr_8;
        for (unsigned int res_l=0; res_l < res_matrix->nb_l; res_l+=NB_L_STRIDE){
                for (unsigned int res_c=0; res_c < res_matrix->nb_c; res_c+=8){
                        for(unsigned int i=0; i<NB_L_STRIDE; i++)
                                res_ptr_8[i] = _mm512_setzero_pd();
                        for (unsigned int offset_A_c = 0; offset_A_c < A->nb_c; offset_A_c++){
                        // compute values from res_matrix[res_l, res_c] to [res_l, res_c+7]
                                // on the res_l th line of A pick values one at a time
                                // at coordinates A[res_l, offset_A_c].
                                // Broadcast this value eight times into a mm512 vector
                                // and perform a dot product with the 8 values found in
                                // B from coordinates [offset_A_c, res_c] to [offset_A_c, res_c + 7]
                                B_ptr_8 = _mm512_load_pd(&B->values[offset_A_c * B->nb_c + res_c]);
                                for(unsigned int i=0; i<NB_L_STRIDE; i++)
                                        res_ptr_8[i] = _mm512_fmadd_pd(
                                                _mm512_set1_pd(A->values[(res_l +i) * A->nb_c + offset_A_c]),
                                                B_ptr_8,
                                                res_ptr_8[i]);

                        }
                        for(unsigned int i=0; i<NB_L_STRIDE; i++)
                                _mm512_store_pd(
                                        &res_matrix->values[(res_l + i)*res_matrix->nb_c + res_c] ,
                                        res_ptr_8[i]);
                }
        }

...