Почему Numba искажает мой массив dtype при использовании numpy случайного перемешивания? - PullRequest
0 голосов
/ 03 мая 2020

Я пытаюсь оптимизировать очень простую функцию, которая просматривает каждую строку двумерного массива d-типа, присваивает некоторые значения его полям и в то же время случайным образом перетасовывает порядок элементов (рассмотрим каждый элемент как частицу, которая должна взаимодействовать с другими частицами, и я буду перетасовывать их в порядке шага за время шага моделирования). Как ни странно, когда я использую numba.njit, чтобы оптимизировать его, мой массив полностью запутался и стал другим. Вот простое воспроизведение моей проблемы:

import numpy as np
import numba as nb

D = np.dtype([('f_id','i2'),('p_id','i2'),('x','f8')])
A = np.zeros((3,6),dtype=D)
A['p_id'] = np.arange(6)
B = np.zeros((3,6),dtype=D)
B['p_id'] = np.arange(6)

def test1(x):
for i in range(3):
    y = x[i][:]
    y['f_id'][:] = i
    y['x'][:] = 1000 + i
    np.random.seed(2000)
    np.random.shuffle(y)
return x

@nb.njit
def test2(X):
for ii in range(3):
    Y = X[ii][:]
    Y['f_id'][:] = ii
    Y['x'][:] = 1000 + ii
    np.random.seed(2000)
    np.random.shuffle(Y)
return X

test1(A)
test2(B)

первая функция без njit работает нормально, как и ожидалось:

A
array([[(0, 1, 1000.), (0, 2, 1000.), (0, 4, 1000.), (0, 5, 1000.),
        (0, 3, 1000.), (0, 0, 1000.)],
       [(1, 1, 1001.), (1, 2, 1001.), (1, 4, 1001.), (1, 5, 1001.),
        (1, 3, 1001.), (1, 0, 1001.)],
       [(2, 1, 1002.), (2, 2, 1002.), (2, 4, 1002.), (2, 5, 1002.),
        (2, 3, 1002.), (2, 0, 1002.)]],
      dtype=[('f_id', '<i2'), ('p_id', '<i2'), ('x', '<f8')])

Но вторая функция каждый раз дает странные результаты:

B
array([[(0, 0, 1000.), (0, 0, 1000.), (0, 0, 1000.), (0, 0, 1000.),
        (0, 3, 1000.), (0, 0, 1000.)],
       [(1, 0, 1001.), (1, 0, 1001.), (1, 0, 1001.), (1, 0, 1001.),
        (1, 3, 1001.), (1, 0, 1001.)],
       [(2, 0, 1002.), (2, 0, 1002.), (2, 0, 1002.), (2, 0, 1002.),
        (2, 3, 1002.), (2, 0, 1002.)]],
      dtype=[('f_id', '<i2'), ('p_id', '<i2'), ('x', '<f8')])

Numba, очевидно, поддерживает случайное перемешивание: http://numba.pydata.org/numba-doc/0.22.1/reference/numpysupported.html.

Стоит отметить, что эта проблема не возникает с массивами numpy, отличными от массива dtype.

По-видимому, Numba не распознает numpy генератор случайных чисел, default_rng (), https://numpy.org/doc/1.18/reference/random/generator.html#numpy .random.default_rng .

Я попытался переместить случайное начальное число в и из функции с сопряжением не работает.

...