Как оптимизировать конвейерные задержки из-за кода L1D Cache Bound в Haswell для многократного ядра Matrix-Matrix? - PullRequest
2 голосов
/ 06 февраля 2020

Я пытаюсь оптимизировать несколько ядер для AVX2, чтобы углубить свое понимание оптимизации микро-архитектурных программ. Одно из ядер, которое я пишу, - это матричное умножение. Я профилировал выполнение с Intel VTune и обнаружил, что код привязан к L1D Cache. Я провел некоторое исследование о том, что это значит, и наткнулся на эту статью от Intel, где изложены некоторые возможные причины. Прочитав различные возможности, я решил некоторые проблемы, добавив аккумулятор для предотвращения более длинной цепочки зависимостей. Я заглянул в сгенерированную сборку, но я не слишком разбираюсь в коде сборки.

Моя реализация выполняет примерно на 50% производительность DGEMM MKL для тех же входов. Это совпадает с VTune, который указывает на 50% производительности остановок конвейера памяти.

Код тратит более 90% времени в следующем ядре:


#include <xmmintrin.h> //AVX2

#define MU 16
#define NU 16
#define KU 8
#define MINIMUM(a,b) (((a)<(b))?(a):(b))

static inline void tiny_dgemm_ijk(const double* A, const double* B, double *restrict C, const int i, const int j, const int k, const int lda, const int ldb, const int ldc){
    __m256d cvec1, cvec2, cvec3, cvec4, cvec5, cvec6, cvec7, cvec8;
    __m256d cvec9, cvec10, cvec11, cvec12, cvec13, cvec14, cvec15, cvec16;
    __m256d avec1, avec2, avec3, avec4, avec5, avec6, avec7, avec8;
    __m256d bvec1, bvec2;
    // Iterate over a chunk of NU columns of C.
    for (int jjj = 0; (jjj < NU); jjj++){
        // Load a chunk of C of dimensions 16 by 1
        cvec1 = _mm256_loadu_pd((__m256d*) (C + i + (jjj+j)*ldc + 0));
        cvec2 = _mm256_loadu_pd((__m256d*) (C + i + (jjj+j)*ldc + 4));
        cvec3 = _mm256_loadu_pd((__m256d*) (C + i + (jjj+j)*ldc + 8));
        cvec4 = _mm256_loadu_pd((__m256d*) (C + i + (jjj+j)*ldc + 12));
        cvec5 = _mm256_set1_pd(0.0);
        cvec6 = _mm256_set1_pd(0.0);
        cvec7 = _mm256_set1_pd(0.0);
        cvec8 = _mm256_set1_pd(0.0);
        // Iterate over KU columns of A and rows of B
        // Perform FMA operation over all 16 rows of A corresponding to the 16 rows of C.
        for (int kkk = 0; (kkk < KU); kkk+=2){
            bvec1 = _mm256_set1_pd(B[(kkk + k + 0) + (jjj + j)*ldb]);
            bvec2 = _mm256_set1_pd(B[(kkk + k + 1) + (jjj + j)*ldb]);

            avec1 = _mm256_loadu_pd((__m256d*) (A + i + (kkk+k)*lda + 0));
            avec2 = _mm256_loadu_pd((__m256d*) (A + i + (kkk+k)*lda + 4));
            cvec1 = _mm256_fmadd_pd(avec1, bvec1, cvec1);
            cvec2 = _mm256_fmadd_pd(avec2, bvec1, cvec2);

            avec3 = _mm256_loadu_pd((__m256d*) (A + i + (kkk+k)*lda + 8));
            avec4 = _mm256_loadu_pd((__m256d*) (A + i + (kkk+k)*lda + 12));
            cvec3 = _mm256_fmadd_pd(avec3, bvec1, cvec3);
            cvec4 = _mm256_fmadd_pd(avec4, bvec1, cvec4);

            avec5 = _mm256_loadu_pd((__m256d*) (A + i + ((kkk+k)+1)*lda + 0));
            avec6 = _mm256_loadu_pd((__m256d*) (A + i + ((kkk+k)+1)*lda + 4));
            cvec5 = _mm256_fmadd_pd(avec5, bvec2, cvec5);
            cvec6 = _mm256_fmadd_pd(avec6, bvec2, cvec6);

            avec7 = _mm256_loadu_pd((__m256d*) (A + i + ((kkk+k)+1)*lda + 8));
            avec8 = _mm256_loadu_pd((__m256d*) (A + i + ((kkk+k)+1)*lda + 12));
            cvec7 = _mm256_fmadd_pd(avec7, bvec2, cvec7);
            cvec8 = _mm256_fmadd_pd(avec8, bvec2, cvec8);
        }
        // Write back to C
        _mm256_storeu_pd((__m256d*) (C + i + (jjj+j)*ldc + 0), _mm256_add_pd(cvec1, cvec5));
        _mm256_storeu_pd((__m256d*) (C + i + (jjj+j)*ldc + 8), _mm256_add_pd(cvec3, cvec7));
        _mm256_storeu_pd((__m256d*) (C + i + (jjj+j)*ldc + 4), _mm256_add_pd(cvec2, cvec6));
        _mm256_storeu_pd((__m256d*) (C + i + (jjj+j)*ldc + 12), _mm256_add_pd(cvec4, cvec8));
    }
}

Для контекста, ядро ​​вызывается внутри следующей функции,

void square_dgemm (const int n, const double* A, const double* B, double *restrict C){
    for (int j = 0; j < n; j += NU){
        for (int k = 0; k < n; k += KU){
            for (int i = 0; i < n; i += MU){
                if (((i+MU) < n) && ((j+NU) < n) && ((k+KU) < n)){
                    // Handle core of the GEMM divisible by chunks.
                    // The remainder of the branches are handled by
                    // less optimal sections. Over 90% of runtime is
                    // spent in tiny_dgemm_ijk.
                    tiny_dgemm_ijk(A, B, C, i, j, k, n, n, n);
                } else if (((i+MU) < n) && ((k+KU) < n)){
                   int J = MINIMUM((j+NU), n);
                   for (int jjj = j; (jjj < J); jjj++){
                       for (int kkk = 0; (kkk < KU/2); kkk++){
                           const double bij1 = B[kkk+k+jjj*n];
                           const double bij2 = B[kkk+k+(KU/2)+jjj*n];
                           for (int iii = 0; (iii < MU/4); iii++){
                               C[iii+i+0+(jjj)*n] += A[iii+i+0+(kkk+k)*n] * bij1;
                               C[iii+i+(MU/4)+0+(jjj)*n] += A[iii+i+(MU/4)+0+(kkk+k)*n] * bij1;
                               C[iii+i+(MU/2)+0+(jjj)*n] += A[iii+i+(MU/2)+0+(kkk+k)*n] * bij1;
                               C[iii+i+(3*(MU/4))+0+(jjj)*n] += A[iii+i+(3*(MU/4))+0+(kkk+k)*n] * bij1;

                               C[iii+i+0+(jjj)*n] += A[iii+i+0+(kkk+k+(KU/2))*n] * bij2;
                               C[iii+i+(MU/4)+0+(jjj)*n] += A[iii+i+(MU/4)+0+(kkk+k+(KU/2))*n] * bij2;
                               C[iii+i+(MU/2)+0+(jjj)*n] += A[iii+i+(MU/2)+0+(kkk+k+(KU/2))*n] * bij2;
                               C[iii+i+(3*(MU/4))+0+(jjj)*n] += A[iii+i+(3*(MU/4))+0+(kkk+k+(KU/2))*n] * bij2;
                           }
                       }
                   }
                } else if ((i+MU) < n) {
                   int J = MINIMUM((j+NU), n);
                   int K = MINIMUM((k+KU), n);
                   for (int jjj = j; (jjj < J); jjj++){
                       for (int kkk = k; (kkk < K); kkk++){
                           const double bij1 = B[kkk+jjj*n];
                           for (int iii = 0; (iii < MU/4); iii++){
                               C[iii+i+0+(jjj)*n] += A[iii+i+0+(kkk)*n] * bij1;
                               C[iii+i+(MU/4)+0+(jjj)*n] += A[iii+i+(MU/4)+0+(kkk)*n] * bij1;
                               C[iii+i+(MU/2)+0+(jjj)*n] += A[iii+i+(MU/2)+0+(kkk)*n] * bij1;
                               C[iii+i+(3*(MU/4))+0+(jjj)*n] += A[iii+i+(3*(MU/4))+0+(kkk)*n] * bij1;
                           }
                       }
                   }
                } else {
                    int J = MINIMUM((j+NU), n);
                    int K = MINIMUM((k+KU), n);
                    for (int jjj = j; (jjj < J); jjj++){
                        for (int kkk = k; (kkk < K); kkk++){
                            const double bij = B[kkk+jjj*n];
                            for (int iii = i; (iii < n); iii++){
                                C[iii+jjj*n] += A[iii+kkk*n] * bij;
                            }
                        }
                    }
               }
            }
        }
    }
}

Я не являюсь экспертом ни в низкоуровневом программировании, ни в оптимизации микроархивов, но я уверен, что для большинства вычислений данные находятся в кеше L1D. Есть ли очевидные причины высокой задержки L1D? Я компилирую с использованием компилятора Intel C и перепробовал несколько комбинаций флагов оптимизации с очень похожими результатами.

Буду очень признателен за любые рекомендации.

ОБНОВЛЕНИЕ

На всякий случай, если кто-нибудь захочет взглянуть на вывод сборки I CC, я извлек соответствующие строки (по крайней мере, я могу сказать).

        movslq    %r14d, %r14                                   #51.49
        xorb      %r9b, %r9b                                    #38.5
        movl      %esi, 1896(%rsp)                              #[spill]
        movq      264(%rsp), %rsi                               #[spill]
        movq      248(%rsp), %rcx                               #[spill]
        vmovupd   64(%rsi,%r14,8), %ymm15                       #56.49
        vmovupd   32(%rsi,%r14,8), %ymm12                       #52.49
        vmovupd   %ymm15, 704(%rsp)                             #56.49[spill]
        vmovupd   96(%rsi,%r14,8), %ymm15                       #57.49
        vmovupd   %ymm12, 1120(%rsp)                            #38.5[spill]
        vmovupd   %ymm15, 672(%rsp)                             #57.49[spill]
        vmovupd   (%rcx,%r14,8), %ymm15                         #61.49
        vmovupd   %ymm15, 640(%rsp)                             #61.49[spill]
        vmovupd   32(%rcx,%r14,8), %ymm15                       #62.49
        movl      %edi, 2056(%rsp)                              #[spill]
        movq      216(%rsp), %r12                               #[spill]
        movq      224(%rsp), %r11                               #[spill]
        movq      240(%rsp), %rdi                               #[spill]
        vmovupd   %ymm15, 608(%rsp)                             #62.49[spill]
        vmovupd   96(%rdi,%r14,8), %ymm13                       #67.49
        vmovupd   32(%r11,%r14,8), %ymm14                       #62.49
        vmovupd   64(%r11,%r14,8), %ymm7                        #66.49
        vmovupd   96(%r11,%r14,8), %ymm8                        #67.49
        vmovupd   (%r12,%r14,8), %ymm9                          #51.49
        vmovupd   96(%r12,%r14,8), %ymm10                       #57.49
        vmovupd   64(%rcx,%r14,8), %ymm15                       #66.49
        vmovupd   (%rdi,%r14,8), %ymm2                          #61.49
        vmovupd   32(%rdi,%r14,8), %ymm3                        #62.49
        vmovupd   64(%rdi,%r14,8), %ymm4                        #66.49
        vmovupd   %ymm13, 544(%rsp)                             #67.49[spill]
        vmovupd   %ymm14, 736(%rsp)                             #62.49[spill]
        vmovupd   %ymm7, 768(%rsp)                              #66.49[spill]
        vmovupd   %ymm8, 800(%rsp)                              #67.49[spill]
        vmovupd   %ymm9, 832(%rsp)                              #51.49[spill]
        vmovupd   %ymm10, 864(%rsp)                             #57.49[spill]
        vmovupd   %ymm15, 576(%rsp)                             #66.49[spill]
        vmovupd   32(%r12,%r14,8), %ymm14                       #52.49
        vmovupd   64(%r12,%r14,8), %ymm13                       #56.49
        vmovupd   (%rsi,%r14,8), %ymm7                          #51.49
        vmovupd   %ymm2, 448(%rsp)                              #61.49[spill]
        vmovupd   96(%rcx,%r14,8), %ymm15                       #67.49
        vmovupd   %ymm3, 480(%rsp)                              #62.49[spill]
        vmovupd   %ymm4, 512(%rsp)                              #66.49[spill]
        vmovupd   %ymm7, 1088(%rsp)                             #38.5[spill]
        vmovupd   %ymm13, 928(%rsp)                             #38.5[spill]
        vmovupd   %ymm15, 352(%rsp)                             #67.49[spill]
        vmovupd   %ymm14, 896(%rsp)                             #38.5[spill]
        movl      %ebx, 344(%rsp)                               #[spill]
        movq      256(%rsp), %rbx                               #[spill]
        movq      208(%rsp), %r13                               #[spill]
        vmovupd   64(%rbx,%r14,8), %ymm0                        #56.49
        vmovupd   96(%rbx,%r14,8), %ymm1                        #57.49
        vmovupd   (%r13,%r14,8), %ymm11                         #61.49
        vmovupd   32(%r13,%r14,8), %ymm10                       #62.49
        vmovupd   64(%r13,%r14,8), %ymm9                        #66.49
        vmovupd   96(%r13,%r14,8), %ymm8                        #67.49
        vmovupd   (%rbx,%r14,8), %ymm6                          #51.49
        vmovupd   32(%rbx,%r14,8), %ymm5                        #52.49
        vmovupd   %ymm0, 384(%rsp)                              #56.49[spill]
        vmovupd   %ymm1, 416(%rsp)                              #57.49[spill]
        vmovupd   (%r11,%r14,8), %ymm0                          #61.49
        vmovupd   %ymm8, 1056(%rsp)                             #38.5[spill]
        vmovupd   %ymm9, 1024(%rsp)                             #38.5[spill]
        vmovupd   %ymm10, 992(%rsp)                             #38.5[spill]
        vmovupd   %ymm11, 960(%rsp)                             #38.5[spill]
        movq      232(%rsp), %r15                               #[spill]
        movq      1976(%rsp), %r8                               #[spill]
        movl      %edx, 1888(%rsp)                              #[spill]
        movl      %r10d, 1200(%rsp)                             #[spill]
        lea       (%r8,%r14,8), %r10                            #39.45
        movl      %eax, 1208(%rsp)                              #[spill]
        xorl      %eax, %eax                                    #38.5
        movl      272(%rsp), %edx                               #[spill]
        movl      344(%rsp), %ebx                               #38.5[spill]
        movl      1896(%rsp), %esi                              #38.5[spill]
        movl      1936(%rsp), %r12d                             #38.5[spill]
        movl      2024(%rsp), %r8d                              #38.5[spill]
        movl      1904(%rsp), %ecx                              #38.5[spill]
        movl      2056(%rsp), %edi                              #38.5[spill]
        movq      1968(%rsp), %r13                              #38.5[spill]
        vmovupd   (%r15,%r14,8), %ymm4                          #51.49
        vmovupd   32(%r15,%r14,8), %ymm3                        #52.49
        vmovupd   64(%r15,%r14,8), %ymm2                        #56.49
        vmovupd   96(%r15,%r14,8), %ymm1                        #57.49
                                # LOE r10 r13 eax edx ecx ebx esi edi r8d r12d r14d r9b ymm0 ymm1 ymm2 ymm3 ymm4 ymm5 ymm6
L_B1.15:                        # Preds L_B1.15 L_B1.14
                                # Execution count [1.10e+02]
        vmovupd   416(%rsp), %ymm8                              #59.21[spill]
        lea       (%r12,%rax), %r11d                            #39.61
        movslq    %r11d, %r11                                   #39.61
        lea       (%rdx,%rax), %r15d                            #48.64
        movslq    %r15d, %r15                                   #48.64
        incb      %r9b                                          #38.5
        addl      %edi, %eax                                    #38.5
        vmovupd   64(%r10,%r11,8), %ymm12                       #45.45
        vmovupd   (%r10,%r11,8), %ymm14                         #39.45
        vmovupd   32(%r10,%r11,8), %ymm13                       #44.45
        vbroadcastsd (%r13,%r15,8), %ymm15                      #48.21
        vbroadcastsd 8(%r13,%r15,8), %ymm7                      #49.21
        vfmadd231pd 384(%rsp), %ymm15, %ymm12                   #58.21[spill]
        vfmadd231pd %ymm15, %ymm6, %ymm14                       #53.21
        vfmadd231pd %ymm15, %ymm5, %ymm13                       #54.21
        vfmadd213pd 96(%r10,%r11,8), %ymm8, %ymm15              #59.21
        vmulpd    448(%rsp), %ymm7, %ymm9                       #63.21[spill]
        vmulpd    480(%rsp), %ymm7, %ymm10                      #64.21[spill]
        vmulpd    512(%rsp), %ymm7, %ymm11                      #68.21[spill]
        vmulpd    544(%rsp), %ymm7, %ymm8                       #69.21[spill]
        vbroadcastsd 16(%r13,%r15,8), %ymm7                     #48.21
        vfmadd231pd %ymm4, %ymm7, %ymm14                        #53.21
        vfmadd231pd %ymm3, %ymm7, %ymm13                        #54.21
        vfmadd231pd %ymm2, %ymm7, %ymm12                        #58.21
        vfmadd231pd %ymm1, %ymm7, %ymm15                        #59.21
        vbroadcastsd 24(%r13,%r15,8), %ymm7                     #49.21
        vfmadd231pd 736(%rsp), %ymm7, %ymm10                    #64.21[spill]
        vfmadd231pd 768(%rsp), %ymm7, %ymm11                    #68.21[spill]
        vfmadd231pd 800(%rsp), %ymm7, %ymm8                     #69.21[spill]
        vfmadd231pd %ymm0, %ymm7, %ymm9                         #63.21
        vbroadcastsd 32(%r13,%r15,8), %ymm7                     #48.21
        vfmadd231pd 832(%rsp), %ymm7, %ymm14                    #53.21[spill]
        vfmadd231pd 896(%rsp), %ymm7, %ymm13                    #54.21[spill]
        vfmadd231pd 928(%rsp), %ymm7, %ymm12                    #58.21[spill]
        vfmadd231pd 864(%rsp), %ymm7, %ymm15                    #59.21[spill]
        vbroadcastsd 40(%r13,%r15,8), %ymm7                     #49.21
        vfmadd231pd 960(%rsp), %ymm7, %ymm9                     #63.21[spill]
        vfmadd231pd 992(%rsp), %ymm7, %ymm10                    #64.21[spill]
        vfmadd231pd 1024(%rsp), %ymm7, %ymm11                   #68.21[spill]
        vfmadd231pd 1056(%rsp), %ymm7, %ymm8                    #69.21[spill]
        vbroadcastsd 48(%r13,%r15,8), %ymm7                     #48.21
        vfmadd231pd 1088(%rsp), %ymm7, %ymm14                   #53.21[spill]
        vfmadd231pd 1120(%rsp), %ymm7, %ymm13                   #54.21[spill]
        vfmadd231pd 704(%rsp), %ymm7, %ymm12                    #58.21[spill]
        vfmadd231pd 672(%rsp), %ymm7, %ymm15                    #59.21[spill]
        vbroadcastsd 56(%r13,%r15,8), %ymm7                     #49.21
        vfmadd231pd 640(%rsp), %ymm7, %ymm9                     #63.21[spill]
        vfmadd231pd 608(%rsp), %ymm7, %ymm10                    #64.21[spill]
        vfmadd231pd 576(%rsp), %ymm7, %ymm11                    #68.21[spill]
        vfmadd231pd 352(%rsp), %ymm7, %ymm8                     #69.21[spill]
        vaddpd    %ymm9, %ymm14, %ymm9                          #72.64
        vaddpd    %ymm10, %ymm13, %ymm10                        #74.64
        vaddpd    %ymm11, %ymm12, %ymm11                        #73.64
        vaddpd    %ymm8, %ymm15, %ymm12                         #75.65
        vmovupd   %ymm9, (%r10,%r11,8)                          #72.38
        vmovupd   %ymm10, 32(%r10,%r11,8)                       #74.38
        vmovupd   %ymm11, 64(%r10,%r11,8)                       #73.38
        vmovupd   %ymm12, 96(%r10,%r11,8)                       #75.38
        cmpb      $16, %r9b                                     #38.5
        jb        L_B1.15       # Prob 93%                      #38.5
...