Присвоение всем записям, чьи индексы суммируют с некоторым значением - PullRequest
0 голосов
/ 09 мая 2019

У меня есть массив X двоичных чисел и формы (2, 2, ..., 2), и я хотел бы присвоить значение 1 всем записям, индексы которых равны 0 по модулю 2, а значение 0 - остальным.

Например, если бы у нас было X.shape = (2, 2, 2), я бы хотел присвоить 1 для X[0, 0, 0], X[0, 1, 1], X[1, 0, 1], X[1, 1, 0] и 0 для других 4 записей.

Какой самый эффективный способ сделать это? Я предполагаю, что должен создать этот массив с типом данных np.bool, поэтому решение должно работать с учетом этого.

1 Ответ

1 голос
/ 10 мая 2019

Вот прямой и хитрый метод.Хитрый использует битовую упаковку и использует определенные повторяющиеся шаблоны.Для больших n это дает значительное ускорение (>50 @ n=19).

import functools as ft
import numpy as np

def direct(n):
    I = np.arange(2, dtype='u1')
    return ft.reduce(np.bitwise_xor, np.ix_(I[::-1], *(n-1)*(I,)))

def smartish(n):
    assert n >= 6
    b = np.empty(1<<(n-3), 'u1')
    b[[0, 3, 5, 6]] = 0b10010110
    b[[1, 2, 4, 7]] = 0b01101001
    i = b.view('u8')
    jp = 1
    for j in range(0, n-7, 2):
        i[3*jp:4*jp] = i[:jp]
        i[jp:3*jp].reshape(2, -1)[...] = 0xffff_ffff_ffff_ffff ^ i[:jp]
        jp *= 4
    if n & 1:
        i[jp:] = 0xffff_ffff_ffff_ffff ^ i[:jp]
    return np.unpackbits(b).reshape(n*(2,))

from timeit import timeit

assert np.all(smartish(19) == direct(19))

print(f"direct   {timeit(lambda: direct(19), number=100)*10:.3f} ms")
print(f"smartish {timeit(lambda: smartish(19), number=100)*10:.3f} ms")

Образец запуска на коробке 2^19:

direct   5.408 ms
smartish 0.079 ms

Обратите внимание, что эти возвращаемые значения uint8 массивы, например:

>>> direct(3)
array([[[1, 0],
        [0, 1]],

       [[0, 1],
        [1, 0]]], dtype=uint8)

Но их можно навести на bool практически при нулевой стоимости:

>>> direct(3).view('?')
array([[[ True, False],
        [False,  True]],

       [[False,  True],
        [ True, False]]])

Объяснитель:

прямой метод: Один простой способ проверить битовую четность - это xor битов вместе.Нам нужно сделать это «сокращающим» способом, то есть мы должны применить двоичную операцию xor к первым двум операндам, затем к результату и третьему операнду, затем к этому результату и четвертому операнду и так далее.Это то, что делает functools.reduce.

Кроме того, мы не хотим делать это только один раз, но в каждой точке 2^n сетки.numpy способ сделать это - открытые сетки.Они могут быть сгенерированы по осям 1D с использованием np.ix_ или в простых случаях с использованием np.ogrid.Обратите внимание, что мы перевернем самую первую ось, чтобы учесть тот факт, что нам нужен инвертированный паритет.

Умный метод.Мы делаем две основные оптимизации.1) xor - это побитовая операция, означающая, что она выполняет «64-сторонние параллельные вычисления» бесплатно, если мы упаковываем наши биты в 64-битную строку.2) Если мы сгладим гиперкуб 2 ^ n, то положение n в линейном расположении соответствует ячейке (бит1, бит2, бит3, ...) в гиперкубе, где бит1, бит2 и т. Д. Является двоичным представлением (с ведущими нулями)п.Теперь обратите внимание, что если мы вычислили четности позиций 0 .. 0b11..11 = 2 ^ k-1, то мы можем получить четности 2 ^ k..2 ^ (k + 1) -1, просто копируя и инвертируяуже вычисленные соотношения.Например, k = 2:

0b000, 0b001, 0b010, 0b011 would be what we have and
0b100, 0b101, 0b110, 0b111 would be what we need to compute
  ^      ^      ^      ^   

Поскольку эти две последовательности отличаются только отмеченным битом, ясно, что действительно их суммы из нескольких цифр отличаются на единицу, а четности инвертированы.

КакВ упражнении вы узнаете, что можно сказать аналогичным образом о следующих 2 ^ k записях и 2 ^ k записях после них.

...