Как оптимизировать этот 8-битный позиционный счетчик с помощью сборки? - PullRequest
6 голосов
/ 04 августа 2020

Этот пост относится к Golang реализации сборки _mm_add_epi32 , где он добавляет парные элементы в два списка [8]int32 и возвращает обновленный первый.

Согласно pprof profile, я обнаружил, что передача [8]int32 стоит дорого, поэтому я думаю, что передача указателя списка намного дешевле, и результат bech подтвердил это. Вот версия go:

func __mm_add_epi32_inplace_purego(x, y *[8]int32) {
    (*x)[0] += (*y)[0]
    (*x)[1] += (*y)[1]
    (*x)[2] += (*y)[2]
    (*x)[3] += (*y)[3]
    (*x)[4] += (*y)[4]
    (*x)[5] += (*y)[5]
    (*x)[6] += (*y)[6]
    (*x)[7] += (*y)[7]
}

Эта функция вызывается на двух уровнях l oop.

Алгоритм вычисляет подсчет популяции позиций по массив байтов.

Спасибо совету от @fuz, я знаю, что написание всего алгоритма на ассемблере - лучший выбор и имеет смысл, но это вне моих возможностей, так как я никогда не изучаю программирование на ассемблере.

Однако должно быть легко оптимизировать внутренний l oop с помощью сборки:

counts := make([][8]int32, numRowBytes)

for i, b = range byteSlice {
    if b == 0 {                  // more than half of elements in byteSlice is 0.
        continue
    }
    expand = _expand_byte[b]
    __mm_add_epi32_inplace_purego(&counts[i], expand)
}

// expands a byte into its bits
var _expand_byte = [256]*[8]int32{
    &[8]int32{0, 0, 0, 0, 0, 0, 0, 0},
    &[8]int32{0, 0, 0, 0, 0, 0, 0, 1},
    &[8]int32{0, 0, 0, 0, 0, 0, 1, 0},
    &[8]int32{0, 0, 0, 0, 0, 0, 1, 1},
    &[8]int32{0, 0, 0, 0, 0, 1, 0, 0},
    ...
}

Можете ли вы помочь написать версию сборки __mm_add_epi32_inplace_purego (мне этого достаточно), или даже целиком л oop? Заранее спасибо.

1 Ответ

6 голосов
/ 04 августа 2020

Операция, которую вы хотите выполнить, называется подсчетом позиционного заполнения в байтах. Это хорошо известная операция, используемая в машинном обучении, и для решения этой проблемы были проведены некоторые исследования быстрых алгоритмов .

К сожалению, реализация этих алгоритмов довольно сложна. По этой причине я разработал собственный алгоритм, который намного проще реализовать, но дает примерно половину производительности по сравнению с другим другим методом. Однако при измерении 10 ГБ / с это все равно должно быть приличным улучшением по сравнению с тем, что было у вас раньше. возьмите скалярный счетчик населения, который затем добавляется к соответствующему счетчику. Это позволяет коротким цепочкам зависимостей и достижению согласованного IP C равного 3.

Обратите внимание, что по сравнению с вашим алгоритмом мой код меняет порядок битов. Вы можете изменить это, отредактировав, к каким элементам массива counts обращается ассемблерный код, если хотите. Однако, в интересах будущих читателей, я хотел бы оставить этот код с более распространенным соглашением, в котором младший бит считается битом 0.

Исходный код

Полный исходный код можно найти на github .

Алгоритм представлен в двух вариантах и ​​был протестирован на машине с процессором, обозначенным как «Intel (R) Xeon (R) W-2133 CPU @ 3.60GHz. ”

Позиционный счетчик популяции 32 байта за раз.

Счетчики хранятся в регистрах общего назначения для лучшей производительности. Память предварительно выбирается заранее для лучшего поведения потоковой передачи. Скалярный хвост обрабатывается с помощью очень простой комбинации SHRL / ADCL. Достигнута производительность до 11 ГБ / с.

#include "textflag.h"

// func PospopcntReg(counts *[8]int32, buf []byte)
TEXT ·PospopcntReg(SB),NOSPLIT,$0-32
    MOVQ counts+0(FP), DI
    MOVQ buf_base+8(FP), SI     // SI = &buf[0]
    MOVQ buf_len+16(FP), CX     // CX = len(buf)

    // load counts into register R8--R15
    MOVL 4*0(DI), R8
    MOVL 4*1(DI), R9
    MOVL 4*2(DI), R10
    MOVL 4*3(DI), R11
    MOVL 4*4(DI), R12
    MOVL 4*5(DI), R13
    MOVL 4*6(DI), R14
    MOVL 4*7(DI), R15

    SUBQ $32, CX            // pre-subtract 32 bit from CX
    JL scalar

vector: VMOVDQU (SI), Y0        // load 32 bytes from buf
    PREFETCHT0 384(SI)      // prefetch some data
    ADDQ $32, SI            // advance SI past them

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R15            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R14            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R13            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R12            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R11            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R10            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R9         // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R8         // add to counter

    SUBQ $32, CX
    JGE vector          // repeat as long as bytes are left

scalar: ADDQ $32, CX            // undo last subtraction
    JE done             // if CX=0, there's nothing left

loop:   MOVBLZX (SI), AX        // load a byte from buf
    INCQ SI             // advance past it

    SHRL $1, AX         // CF=LSB, shift byte to the right
    ADCL $0, R8         // add CF to R8

    SHRL $1, AX
    ADCL $0, R9         // add CF to R9

    SHRL $1, AX
    ADCL $0, R10            // add CF to R10

    SHRL $1, AX
    ADCL $0, R11            // add CF to R11

    SHRL $1, AX
    ADCL $0, R12            // add CF to R12

    SHRL $1, AX
    ADCL $0, R13            // add CF to R13

    SHRL $1, AX
    ADCL $0, R14            // add CF to R14

    SHRL $1, AX
    ADCL $0, R15            // add CF to R15

    DECQ CX             // mark this byte as done
    JNE loop            // and proceed if any bytes are left

    // write R8--R15 back to counts
done:   MOVL R8, 4*0(DI)
    MOVL R9, 4*1(DI)
    MOVL R10, 4*2(DI)
    MOVL R11, 4*3(DI)
    MOVL R12, 4*4(DI)
    MOVL R13, 4*5(DI)
    MOVL R14, 4*6(DI)
    MOVL R15, 4*7(DI)

    VZEROUPPER          // restore SSE-compatibility
    RET

Позиционный счетчик популяции 96 байт за раз с CSA

Этот вариант выполняет все вышеупомянутые оптимизации, но уменьшает 96 байт до 64, используя заранее один шаг CSA. Как и ожидалось, это улучшает производительность примерно на 30% и достигает 16 ГБ / с.

#include "textflag.h"

// func PospopcntRegCSA(counts *[8]int32, buf []byte)
TEXT ·PospopcntRegCSA(SB),NOSPLIT,$0-32
    MOVQ counts+0(FP), DI
    MOVQ buf_base+8(FP), SI     // SI = &buf[0]
    MOVQ buf_len+16(FP), CX     // CX = len(buf)

    // load counts into register R8--R15
    MOVL 4*0(DI), R8
    MOVL 4*1(DI), R9
    MOVL 4*2(DI), R10
    MOVL 4*3(DI), R11
    MOVL 4*4(DI), R12
    MOVL 4*5(DI), R13
    MOVL 4*6(DI), R14
    MOVL 4*7(DI), R15

    SUBQ $96, CX            // pre-subtract 32 bit from CX
    JL scalar

vector: VMOVDQU (SI), Y0        // load 96 bytes from buf into Y0--Y2
    VMOVDQU 32(SI), Y1
    VMOVDQU 64(SI), Y2
    ADDQ $96, SI            // advance SI past them
    PREFETCHT0 320(SI)
    PREFETCHT0 384(SI)

    VPXOR Y0, Y1, Y3        // first adder: sum
    VPAND Y0, Y1, Y0        // first adder: carry out
    VPAND Y2, Y3, Y1        // second adder: carry out
    VPXOR Y2, Y3, Y2        // second adder: sum (full sum)
    VPOR Y0, Y1, Y0         // full adder: carry out

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R15

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R14

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R13

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R12

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R11

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R10

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R9

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R8

    SUBQ $96, CX
    JGE vector          // repeat as long as bytes are left

scalar: ADDQ $96, CX            // undo last subtraction
    JE done             // if CX=0, there's nothing left

loop:   MOVBLZX (SI), AX        // load a byte from buf
    INCQ SI             // advance past it

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R8         // add it to R8

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R9         // add it to R9

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R10            // add it to R10

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R11            // add it to R11

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R12            // add it to R12

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R13            // add it to R13

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R14            // add it to R14

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R15            // add it to R15

    DECQ CX             // mark this byte as done
    JNE loop            // and proceed if any bytes are left

    // write R8--R15 back to counts
done:   MOVL R8, 4*0(DI)
    MOVL R9, 4*1(DI)
    MOVL R10, 4*2(DI)
    MOVL R11, 4*3(DI)
    MOVL R12, 4*4(DI)
    MOVL R13, 4*5(DI)
    MOVL R14, 4*6(DI)
    MOVL R15, 4*7(DI)

    VZEROUPPER          // restore SSE-compatibility
    RET

Тесты

Вот тесты для двух алгоритмов и простая эталонная реализация в чистом виде. Go. Полные тесты можно найти в репозитории github.

BenchmarkReference/10-12    12448764            80.9 ns/op   123.67 MB/s
BenchmarkReference/32-12     4357808           258 ns/op     124.25 MB/s
BenchmarkReference/1000-12            151173          7889 ns/op     126.76 MB/s
BenchmarkReference/2000-12             68959         15774 ns/op     126.79 MB/s
BenchmarkReference/4000-12             36481         31619 ns/op     126.51 MB/s
BenchmarkReference/10000-12            14804         78917 ns/op     126.72 MB/s
BenchmarkReference/100000-12            1540        789450 ns/op     126.67 MB/s
BenchmarkReference/10000000-12            14      77782267 ns/op     128.56 MB/s
BenchmarkReference/1000000000-12           1    7781360044 ns/op     128.51 MB/s
BenchmarkReg/10-12                  49255107            24.5 ns/op   407.42 MB/s
BenchmarkReg/32-12                  186935192            6.40 ns/op 4998.53 MB/s
BenchmarkReg/1000-12                 8778610           115 ns/op    8677.33 MB/s
BenchmarkReg/2000-12                 5358495           208 ns/op    9635.30 MB/s
BenchmarkReg/4000-12                 3385945           357 ns/op    11200.23 MB/s
BenchmarkReg/10000-12                1298670           901 ns/op    11099.24 MB/s
BenchmarkReg/100000-12                115629          8662 ns/op    11544.98 MB/s
BenchmarkReg/10000000-12                1270        916817 ns/op    10907.30 MB/s
BenchmarkReg/1000000000-12                12      93609392 ns/op    10682.69 MB/s
BenchmarkRegCSA/10-12               48337226            23.9 ns/op   417.92 MB/s
BenchmarkRegCSA/32-12               12843939            80.2 ns/op   398.86 MB/s
BenchmarkRegCSA/1000-12              7175629           150 ns/op    6655.70 MB/s
BenchmarkRegCSA/2000-12              3988408           295 ns/op    6776.20 MB/s
BenchmarkRegCSA/4000-12              3016693           382 ns/op    10467.41 MB/s
BenchmarkRegCSA/10000-12             1810195           642 ns/op    15575.65 MB/s
BenchmarkRegCSA/100000-12             191974          6229 ns/op    16053.40 MB/s
BenchmarkRegCSA/10000000-12             1622        698856 ns/op    14309.10 MB/s
BenchmarkRegCSA/1000000000-12             16      68540642 ns/op    14589.88 MB/s
...