Я пытаюсь оптимизировать несколько ядер для AVX2, чтобы углубить свое понимание оптимизации микро-архитектурных программ. Одно из ядер, которое я пишу, - это матричное умножение. Я профилировал выполнение с Intel VTune и обнаружил, что код привязан к L1D Cache. Я провел некоторое исследование о том, что это значит, и наткнулся на эту статью от Intel, где изложены некоторые возможные причины. Прочитав различные возможности, я решил некоторые проблемы, добавив аккумулятор для предотвращения более длинной цепочки зависимостей. Я заглянул в сгенерированную сборку, но я не слишком разбираюсь в коде сборки.
Моя реализация выполняет примерно на 50% производительность DGEMM MKL для тех же входов. Это совпадает с VTune, который указывает на 50% производительности остановок конвейера памяти.
Код тратит более 90% времени в следующем ядре:
#include <xmmintrin.h> //AVX2
#define MU 16
#define NU 16
#define KU 8
#define MINIMUM(a,b) (((a)<(b))?(a):(b))
static inline void tiny_dgemm_ijk(const double* A, const double* B, double *restrict C, const int i, const int j, const int k, const int lda, const int ldb, const int ldc){
__m256d cvec1, cvec2, cvec3, cvec4, cvec5, cvec6, cvec7, cvec8;
__m256d cvec9, cvec10, cvec11, cvec12, cvec13, cvec14, cvec15, cvec16;
__m256d avec1, avec2, avec3, avec4, avec5, avec6, avec7, avec8;
__m256d bvec1, bvec2;
// Iterate over a chunk of NU columns of C.
for (int jjj = 0; (jjj < NU); jjj++){
// Load a chunk of C of dimensions 16 by 1
cvec1 = _mm256_loadu_pd((__m256d*) (C + i + (jjj+j)*ldc + 0));
cvec2 = _mm256_loadu_pd((__m256d*) (C + i + (jjj+j)*ldc + 4));
cvec3 = _mm256_loadu_pd((__m256d*) (C + i + (jjj+j)*ldc + 8));
cvec4 = _mm256_loadu_pd((__m256d*) (C + i + (jjj+j)*ldc + 12));
cvec5 = _mm256_set1_pd(0.0);
cvec6 = _mm256_set1_pd(0.0);
cvec7 = _mm256_set1_pd(0.0);
cvec8 = _mm256_set1_pd(0.0);
// Iterate over KU columns of A and rows of B
// Perform FMA operation over all 16 rows of A corresponding to the 16 rows of C.
for (int kkk = 0; (kkk < KU); kkk+=2){
bvec1 = _mm256_set1_pd(B[(kkk + k + 0) + (jjj + j)*ldb]);
bvec2 = _mm256_set1_pd(B[(kkk + k + 1) + (jjj + j)*ldb]);
avec1 = _mm256_loadu_pd((__m256d*) (A + i + (kkk+k)*lda + 0));
avec2 = _mm256_loadu_pd((__m256d*) (A + i + (kkk+k)*lda + 4));
cvec1 = _mm256_fmadd_pd(avec1, bvec1, cvec1);
cvec2 = _mm256_fmadd_pd(avec2, bvec1, cvec2);
avec3 = _mm256_loadu_pd((__m256d*) (A + i + (kkk+k)*lda + 8));
avec4 = _mm256_loadu_pd((__m256d*) (A + i + (kkk+k)*lda + 12));
cvec3 = _mm256_fmadd_pd(avec3, bvec1, cvec3);
cvec4 = _mm256_fmadd_pd(avec4, bvec1, cvec4);
avec5 = _mm256_loadu_pd((__m256d*) (A + i + ((kkk+k)+1)*lda + 0));
avec6 = _mm256_loadu_pd((__m256d*) (A + i + ((kkk+k)+1)*lda + 4));
cvec5 = _mm256_fmadd_pd(avec5, bvec2, cvec5);
cvec6 = _mm256_fmadd_pd(avec6, bvec2, cvec6);
avec7 = _mm256_loadu_pd((__m256d*) (A + i + ((kkk+k)+1)*lda + 8));
avec8 = _mm256_loadu_pd((__m256d*) (A + i + ((kkk+k)+1)*lda + 12));
cvec7 = _mm256_fmadd_pd(avec7, bvec2, cvec7);
cvec8 = _mm256_fmadd_pd(avec8, bvec2, cvec8);
}
// Write back to C
_mm256_storeu_pd((__m256d*) (C + i + (jjj+j)*ldc + 0), _mm256_add_pd(cvec1, cvec5));
_mm256_storeu_pd((__m256d*) (C + i + (jjj+j)*ldc + 8), _mm256_add_pd(cvec3, cvec7));
_mm256_storeu_pd((__m256d*) (C + i + (jjj+j)*ldc + 4), _mm256_add_pd(cvec2, cvec6));
_mm256_storeu_pd((__m256d*) (C + i + (jjj+j)*ldc + 12), _mm256_add_pd(cvec4, cvec8));
}
}
Для контекста, ядро вызывается внутри следующей функции,
void square_dgemm (const int n, const double* A, const double* B, double *restrict C){
for (int j = 0; j < n; j += NU){
for (int k = 0; k < n; k += KU){
for (int i = 0; i < n; i += MU){
if (((i+MU) < n) && ((j+NU) < n) && ((k+KU) < n)){
// Handle core of the GEMM divisible by chunks.
// The remainder of the branches are handled by
// less optimal sections. Over 90% of runtime is
// spent in tiny_dgemm_ijk.
tiny_dgemm_ijk(A, B, C, i, j, k, n, n, n);
} else if (((i+MU) < n) && ((k+KU) < n)){
int J = MINIMUM((j+NU), n);
for (int jjj = j; (jjj < J); jjj++){
for (int kkk = 0; (kkk < KU/2); kkk++){
const double bij1 = B[kkk+k+jjj*n];
const double bij2 = B[kkk+k+(KU/2)+jjj*n];
for (int iii = 0; (iii < MU/4); iii++){
C[iii+i+0+(jjj)*n] += A[iii+i+0+(kkk+k)*n] * bij1;
C[iii+i+(MU/4)+0+(jjj)*n] += A[iii+i+(MU/4)+0+(kkk+k)*n] * bij1;
C[iii+i+(MU/2)+0+(jjj)*n] += A[iii+i+(MU/2)+0+(kkk+k)*n] * bij1;
C[iii+i+(3*(MU/4))+0+(jjj)*n] += A[iii+i+(3*(MU/4))+0+(kkk+k)*n] * bij1;
C[iii+i+0+(jjj)*n] += A[iii+i+0+(kkk+k+(KU/2))*n] * bij2;
C[iii+i+(MU/4)+0+(jjj)*n] += A[iii+i+(MU/4)+0+(kkk+k+(KU/2))*n] * bij2;
C[iii+i+(MU/2)+0+(jjj)*n] += A[iii+i+(MU/2)+0+(kkk+k+(KU/2))*n] * bij2;
C[iii+i+(3*(MU/4))+0+(jjj)*n] += A[iii+i+(3*(MU/4))+0+(kkk+k+(KU/2))*n] * bij2;
}
}
}
} else if ((i+MU) < n) {
int J = MINIMUM((j+NU), n);
int K = MINIMUM((k+KU), n);
for (int jjj = j; (jjj < J); jjj++){
for (int kkk = k; (kkk < K); kkk++){
const double bij1 = B[kkk+jjj*n];
for (int iii = 0; (iii < MU/4); iii++){
C[iii+i+0+(jjj)*n] += A[iii+i+0+(kkk)*n] * bij1;
C[iii+i+(MU/4)+0+(jjj)*n] += A[iii+i+(MU/4)+0+(kkk)*n] * bij1;
C[iii+i+(MU/2)+0+(jjj)*n] += A[iii+i+(MU/2)+0+(kkk)*n] * bij1;
C[iii+i+(3*(MU/4))+0+(jjj)*n] += A[iii+i+(3*(MU/4))+0+(kkk)*n] * bij1;
}
}
}
} else {
int J = MINIMUM((j+NU), n);
int K = MINIMUM((k+KU), n);
for (int jjj = j; (jjj < J); jjj++){
for (int kkk = k; (kkk < K); kkk++){
const double bij = B[kkk+jjj*n];
for (int iii = i; (iii < n); iii++){
C[iii+jjj*n] += A[iii+kkk*n] * bij;
}
}
}
}
}
}
}
}
Я не являюсь экспертом ни в низкоуровневом программировании, ни в оптимизации микроархивов, но я уверен, что для большинства вычислений данные находятся в кеше L1D. Есть ли очевидные причины высокой задержки L1D? Я компилирую с использованием компилятора Intel C и перепробовал несколько комбинаций флагов оптимизации с очень похожими результатами.
Буду очень признателен за любые рекомендации.
ОБНОВЛЕНИЕ
На всякий случай, если кто-нибудь захочет взглянуть на вывод сборки I CC, я извлек соответствующие строки (по крайней мере, я могу сказать).
movslq %r14d, %r14 #51.49
xorb %r9b, %r9b #38.5
movl %esi, 1896(%rsp) #[spill]
movq 264(%rsp), %rsi #[spill]
movq 248(%rsp), %rcx #[spill]
vmovupd 64(%rsi,%r14,8), %ymm15 #56.49
vmovupd 32(%rsi,%r14,8), %ymm12 #52.49
vmovupd %ymm15, 704(%rsp) #56.49[spill]
vmovupd 96(%rsi,%r14,8), %ymm15 #57.49
vmovupd %ymm12, 1120(%rsp) #38.5[spill]
vmovupd %ymm15, 672(%rsp) #57.49[spill]
vmovupd (%rcx,%r14,8), %ymm15 #61.49
vmovupd %ymm15, 640(%rsp) #61.49[spill]
vmovupd 32(%rcx,%r14,8), %ymm15 #62.49
movl %edi, 2056(%rsp) #[spill]
movq 216(%rsp), %r12 #[spill]
movq 224(%rsp), %r11 #[spill]
movq 240(%rsp), %rdi #[spill]
vmovupd %ymm15, 608(%rsp) #62.49[spill]
vmovupd 96(%rdi,%r14,8), %ymm13 #67.49
vmovupd 32(%r11,%r14,8), %ymm14 #62.49
vmovupd 64(%r11,%r14,8), %ymm7 #66.49
vmovupd 96(%r11,%r14,8), %ymm8 #67.49
vmovupd (%r12,%r14,8), %ymm9 #51.49
vmovupd 96(%r12,%r14,8), %ymm10 #57.49
vmovupd 64(%rcx,%r14,8), %ymm15 #66.49
vmovupd (%rdi,%r14,8), %ymm2 #61.49
vmovupd 32(%rdi,%r14,8), %ymm3 #62.49
vmovupd 64(%rdi,%r14,8), %ymm4 #66.49
vmovupd %ymm13, 544(%rsp) #67.49[spill]
vmovupd %ymm14, 736(%rsp) #62.49[spill]
vmovupd %ymm7, 768(%rsp) #66.49[spill]
vmovupd %ymm8, 800(%rsp) #67.49[spill]
vmovupd %ymm9, 832(%rsp) #51.49[spill]
vmovupd %ymm10, 864(%rsp) #57.49[spill]
vmovupd %ymm15, 576(%rsp) #66.49[spill]
vmovupd 32(%r12,%r14,8), %ymm14 #52.49
vmovupd 64(%r12,%r14,8), %ymm13 #56.49
vmovupd (%rsi,%r14,8), %ymm7 #51.49
vmovupd %ymm2, 448(%rsp) #61.49[spill]
vmovupd 96(%rcx,%r14,8), %ymm15 #67.49
vmovupd %ymm3, 480(%rsp) #62.49[spill]
vmovupd %ymm4, 512(%rsp) #66.49[spill]
vmovupd %ymm7, 1088(%rsp) #38.5[spill]
vmovupd %ymm13, 928(%rsp) #38.5[spill]
vmovupd %ymm15, 352(%rsp) #67.49[spill]
vmovupd %ymm14, 896(%rsp) #38.5[spill]
movl %ebx, 344(%rsp) #[spill]
movq 256(%rsp), %rbx #[spill]
movq 208(%rsp), %r13 #[spill]
vmovupd 64(%rbx,%r14,8), %ymm0 #56.49
vmovupd 96(%rbx,%r14,8), %ymm1 #57.49
vmovupd (%r13,%r14,8), %ymm11 #61.49
vmovupd 32(%r13,%r14,8), %ymm10 #62.49
vmovupd 64(%r13,%r14,8), %ymm9 #66.49
vmovupd 96(%r13,%r14,8), %ymm8 #67.49
vmovupd (%rbx,%r14,8), %ymm6 #51.49
vmovupd 32(%rbx,%r14,8), %ymm5 #52.49
vmovupd %ymm0, 384(%rsp) #56.49[spill]
vmovupd %ymm1, 416(%rsp) #57.49[spill]
vmovupd (%r11,%r14,8), %ymm0 #61.49
vmovupd %ymm8, 1056(%rsp) #38.5[spill]
vmovupd %ymm9, 1024(%rsp) #38.5[spill]
vmovupd %ymm10, 992(%rsp) #38.5[spill]
vmovupd %ymm11, 960(%rsp) #38.5[spill]
movq 232(%rsp), %r15 #[spill]
movq 1976(%rsp), %r8 #[spill]
movl %edx, 1888(%rsp) #[spill]
movl %r10d, 1200(%rsp) #[spill]
lea (%r8,%r14,8), %r10 #39.45
movl %eax, 1208(%rsp) #[spill]
xorl %eax, %eax #38.5
movl 272(%rsp), %edx #[spill]
movl 344(%rsp), %ebx #38.5[spill]
movl 1896(%rsp), %esi #38.5[spill]
movl 1936(%rsp), %r12d #38.5[spill]
movl 2024(%rsp), %r8d #38.5[spill]
movl 1904(%rsp), %ecx #38.5[spill]
movl 2056(%rsp), %edi #38.5[spill]
movq 1968(%rsp), %r13 #38.5[spill]
vmovupd (%r15,%r14,8), %ymm4 #51.49
vmovupd 32(%r15,%r14,8), %ymm3 #52.49
vmovupd 64(%r15,%r14,8), %ymm2 #56.49
vmovupd 96(%r15,%r14,8), %ymm1 #57.49
# LOE r10 r13 eax edx ecx ebx esi edi r8d r12d r14d r9b ymm0 ymm1 ymm2 ymm3 ymm4 ymm5 ymm6
L_B1.15: # Preds L_B1.15 L_B1.14
# Execution count [1.10e+02]
vmovupd 416(%rsp), %ymm8 #59.21[spill]
lea (%r12,%rax), %r11d #39.61
movslq %r11d, %r11 #39.61
lea (%rdx,%rax), %r15d #48.64
movslq %r15d, %r15 #48.64
incb %r9b #38.5
addl %edi, %eax #38.5
vmovupd 64(%r10,%r11,8), %ymm12 #45.45
vmovupd (%r10,%r11,8), %ymm14 #39.45
vmovupd 32(%r10,%r11,8), %ymm13 #44.45
vbroadcastsd (%r13,%r15,8), %ymm15 #48.21
vbroadcastsd 8(%r13,%r15,8), %ymm7 #49.21
vfmadd231pd 384(%rsp), %ymm15, %ymm12 #58.21[spill]
vfmadd231pd %ymm15, %ymm6, %ymm14 #53.21
vfmadd231pd %ymm15, %ymm5, %ymm13 #54.21
vfmadd213pd 96(%r10,%r11,8), %ymm8, %ymm15 #59.21
vmulpd 448(%rsp), %ymm7, %ymm9 #63.21[spill]
vmulpd 480(%rsp), %ymm7, %ymm10 #64.21[spill]
vmulpd 512(%rsp), %ymm7, %ymm11 #68.21[spill]
vmulpd 544(%rsp), %ymm7, %ymm8 #69.21[spill]
vbroadcastsd 16(%r13,%r15,8), %ymm7 #48.21
vfmadd231pd %ymm4, %ymm7, %ymm14 #53.21
vfmadd231pd %ymm3, %ymm7, %ymm13 #54.21
vfmadd231pd %ymm2, %ymm7, %ymm12 #58.21
vfmadd231pd %ymm1, %ymm7, %ymm15 #59.21
vbroadcastsd 24(%r13,%r15,8), %ymm7 #49.21
vfmadd231pd 736(%rsp), %ymm7, %ymm10 #64.21[spill]
vfmadd231pd 768(%rsp), %ymm7, %ymm11 #68.21[spill]
vfmadd231pd 800(%rsp), %ymm7, %ymm8 #69.21[spill]
vfmadd231pd %ymm0, %ymm7, %ymm9 #63.21
vbroadcastsd 32(%r13,%r15,8), %ymm7 #48.21
vfmadd231pd 832(%rsp), %ymm7, %ymm14 #53.21[spill]
vfmadd231pd 896(%rsp), %ymm7, %ymm13 #54.21[spill]
vfmadd231pd 928(%rsp), %ymm7, %ymm12 #58.21[spill]
vfmadd231pd 864(%rsp), %ymm7, %ymm15 #59.21[spill]
vbroadcastsd 40(%r13,%r15,8), %ymm7 #49.21
vfmadd231pd 960(%rsp), %ymm7, %ymm9 #63.21[spill]
vfmadd231pd 992(%rsp), %ymm7, %ymm10 #64.21[spill]
vfmadd231pd 1024(%rsp), %ymm7, %ymm11 #68.21[spill]
vfmadd231pd 1056(%rsp), %ymm7, %ymm8 #69.21[spill]
vbroadcastsd 48(%r13,%r15,8), %ymm7 #48.21
vfmadd231pd 1088(%rsp), %ymm7, %ymm14 #53.21[spill]
vfmadd231pd 1120(%rsp), %ymm7, %ymm13 #54.21[spill]
vfmadd231pd 704(%rsp), %ymm7, %ymm12 #58.21[spill]
vfmadd231pd 672(%rsp), %ymm7, %ymm15 #59.21[spill]
vbroadcastsd 56(%r13,%r15,8), %ymm7 #49.21
vfmadd231pd 640(%rsp), %ymm7, %ymm9 #63.21[spill]
vfmadd231pd 608(%rsp), %ymm7, %ymm10 #64.21[spill]
vfmadd231pd 576(%rsp), %ymm7, %ymm11 #68.21[spill]
vfmadd231pd 352(%rsp), %ymm7, %ymm8 #69.21[spill]
vaddpd %ymm9, %ymm14, %ymm9 #72.64
vaddpd %ymm10, %ymm13, %ymm10 #74.64
vaddpd %ymm11, %ymm12, %ymm11 #73.64
vaddpd %ymm8, %ymm15, %ymm12 #75.65
vmovupd %ymm9, (%r10,%r11,8) #72.38
vmovupd %ymm10, 32(%r10,%r11,8) #74.38
vmovupd %ymm11, 64(%r10,%r11,8) #73.38
vmovupd %ymm12, 96(%r10,%r11,8) #75.38
cmpb $16, %r9b #38.5
jb L_B1.15 # Prob 93% #38.5