Подсчет 1 бита (подсчет населения) для больших данных с использованием AVX-512 или AVX-2 - PullRequest
0 голосов
/ 29 апреля 2018

У меня длинный кусок памяти, скажем, 256 КиБ или больше. Я хочу подсчитать количество 1 битов во всем этом фрагменте, или другими словами: сложить значения «количества населения» для всех байтов.

Я знаю, что AVX-512 имеет инструкцию VPOPCNTDQ , которая подсчитывает количество 1 бит в каждом последующем 64 бита в 512-битном векторе, и у IIANM должна быть возможность выдавать один из них каждый цикл (если имеется соответствующий векторный регистр SIMD) - но у меня нет никакого опыта написания кода SIMD (я скорее специалист по GPU). Кроме того, я не уверен на 100% в поддержке компилятора для целей AVX-512.

На большинстве процессоров AVX-512 все же не поддерживается (полностью); но AVX-2 широко доступен. Я не смог найти менее 512-битную векторизованную инструкцию, похожую на VPOPCNTDQ, поэтому даже теоретически я не уверен, как быстро считать биты с процессорами, поддерживающими AVX-2; может быть, что-то подобное существует, и я просто как-то пропустил?

В любом случае, я был бы признателен за короткую функцию C / C ++ - либо с использованием некоторой библиотеки-обёртки, либо со встроенной сборкой - для каждого из двух наборов команд. Подпись

uint64_t count_bits(void* ptr, size_t size);

Примечания:

Ответы [ 2 ]

0 голосов
/ 29 апреля 2018

Функции большого массива Войцеха Мулы выглядят оптимально, за исключением скалярных циклов очистки. (Подробнее об основных циклах см. В ответе @ einpoklum).

LUT с 256 записями, которое вы используете только пару раз в конце, вероятно, будет отсутствовать в кэше и не будет оптимальным для более чем 1 байта, даже если кэш был горячим. Я полагаю, что все процессоры AVX2 имеют аппаратное обеспечение popcnt, и мы можем легко изолировать последние до 8 байтов, которые еще не были подсчитаны, чтобы настроить нас на один popcnt.

Как обычно с алгоритмами SIMD, он часто хорошо работает для загрузки на всю ширину, которая заканчивается на последнем байте буфера. Но в отличие от векторного регистра, сдвиги с переменным счетом полного целочисленного регистра дешевы (особенно с BMI2). Popcnt не волнует , где биты, поэтому мы можем просто использовать сдвиг вместо необходимости создавать маску AND или что-то еще.

// untested
// ptr points at the first byte that hasn't been counted yet
uint64_t final_bytes = reinterpret_cast<const uint64_t*>(end)[-1] >> (8*(end-ptr));
total += _mm_popcnt_u64( final_bytes );
// Careful, this could read outside a small buffer.

Или, что еще лучше, используйте более сложную логику, чтобы избежать пересечения страниц. Это позволяет избежать пересечения страниц для 6-байтового буфера в начале страницы, например.

0 голосов
/ 29 апреля 2018

AVX-2

@ HadiBreis 'комментирует ссылки на статью о быстром подсчете населения с помощью SSSE3, автор Wojciech Muła; статья ссылается на этот репозиторий GitHub ; а репозиторий имеет следующую реализацию AVX-2. Он основан на векторизованной инструкции поиска и использует таблицу поиска из 16 значений для количества битов в полубайтах.

#   include <immintrin.h>
#   include <x86intrin.h>

std::uint64_t popcnt_AVX2_lookup(const uint8_t* data, const size_t n) {

    size_t i = 0;

    const __m256i lookup = _mm256_setr_epi8(
        /* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
        /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
        /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
        /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4,

        /* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
        /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
        /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
        /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4
    );

    const __m256i low_mask = _mm256_set1_epi8(0x0f);

    __m256i acc = _mm256_setzero_si256();

#define ITER { \
        const __m256i vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data + i)); \
        const __m256i lo  = _mm256_and_si256(vec, low_mask); \
        const __m256i hi  = _mm256_and_si256(_mm256_srli_epi16(vec, 4), low_mask); \
        const __m256i popcnt1 = _mm256_shuffle_epi8(lookup, lo); \
        const __m256i popcnt2 = _mm256_shuffle_epi8(lookup, hi); \
        local = _mm256_add_epi8(local, popcnt1); \
        local = _mm256_add_epi8(local, popcnt2); \
        i += 32; \
    }

    while (i + 8*32 <= n) {
        __m256i local = _mm256_setzero_si256();
        ITER ITER ITER ITER
        ITER ITER ITER ITER
        acc = _mm256_add_epi64(acc, _mm256_sad_epu8(local, _mm256_setzero_si256()));
    }

    __m256i local = _mm256_setzero_si256();

    while (i + 32 <= n) {
        ITER;
    }

    acc = _mm256_add_epi64(acc, _mm256_sad_epu8(local, _mm256_setzero_si256()));

#undef ITER

    uint64_t result = 0;

    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 0));
    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 1));
    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 2));
    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 3));

    for (/**/; i < n; i++) {
        result += lookup8bit[data[i]];
    }

    return result;
}

AVX-512

В этом же хранилище также реализована реализация AVX-512 на основе VPOPCNT:

#   include <immintrin.h>
#   include <x86intrin.h>

uint64_t avx512_vpopcnt(const uint8_t* data, const size_t size) {

    const size_t chunks = size / 64;

    uint8_t* ptr = const_cast<uint8_t*>(data);
    const uint8_t* end = ptr + size;

    // count using AVX512 registers
    __m512i accumulator = _mm512_setzero_si512();
    for (size_t i=0; i < chunks; i++, ptr += 64) {

        // Note: a short chain of dependencies, likely unrolling will be needed.
        const __m512i v = _mm512_loadu_si512((const __m512i*)ptr);
        const __m512i p = _mm512_popcnt_epi64(v);

        accumulator = _mm512_add_epi64(accumulator, p);
    }

    // horizontal sum of a register
    uint64_t tmp[8] __attribute__((aligned(64)));
    _mm512_store_si512((__m512i*)tmp, accumulator);

    uint64_t total = 0;
    for (size_t i=0; i < 8; i++) {
        total += tmp[i];
    }

    // popcount the tail
    while (ptr + 8 < end) {
        total += _mm_popcnt_u64(*reinterpret_cast<const uint64_t*>(ptr));
        ptr += 8;
    }

    while (ptr < end) {
        total += lookup8bit[*ptr++];
    }

    return total;
}

lookup8bit - это таблица поиска popcnt для байтов, а не битов, и определяется здесь . edit: Как отмечают комментаторы, использование 8-битной справочной таблицы в конце не очень хорошая идея, и ее можно улучшить.

...