Numba - как я могу извлечь значение из многомерного массива NumPy - PullRequest
0 голосов
/ 02 мая 2018

Я использую numba и numpy для написания функции, и в ходе своей функции я вычислю эти два элемента: idx, который представляет собой список координат (например, idx = [3,4,5]), и values, трехмерный массив numpy (например, values.shape даст (100, 100, 100)). Обратите внимание, что размер values может меняться произвольно.

Обычно, если я сделаю values[3,4,5], он вернет число, которое является значением в указанной координате. Однако, если я сделаю values[idx], я получу массив! Я знаю, что это работает: values[tuple(idx)], но это приведет к ошибке в numba:

TypingError: cannot determine Numba type of <class 'type'>

Я не могу сделать idx кортежем для начала, потому что idx создается в цикле for, который добавляет элементы к idx, который определен как пустой список перед циклом.

Есть ли простой способ извлечь значение из трехмерного массива с заданным списком, координаты которого находятся в каждом из измерений? Эту, казалось бы, легкую проблему невероятно трудно решить.

Это минимально воспроизводимая функция:

gridMat = (np.linspace(1, 3, 100), np.linspace(0,2,100), 
        np.linspace(5, 6,100))
mGrid = np.meshgrid(*gridMat, indexing = 'ij')
values = np.power(mGrid[0], 2) + mGrid[1] / 4.0 + 3.0 + mGrid[2] * 4

@jit(nopython=True)
def testFunction(values):
    idx = []
    N = 3
    for n in range(N):
        idx.append(n + 1)

    idx_res = tuple(idx)

    print(values[idx_res])

1 Ответ

0 голосов
/ 02 мая 2018

Поскольку никто не отвечает, я опубликую свой комментарий в качестве ответа:

Что-то не так с values[idx[0], idx[1], idx[2]]? Для больших idx это было бы непрактично, но кажется, что values.shape[0] фиксировано и равно 3, поэтому это базовое решение выглядит как способ для меня.

В противном случае, решение, которое приходит на ум, состоит в том, чтобы циклически проходить через dim и получать доступ к элементу плоского представления значений:

offset = 1
for d in range(values.ndim):
    offset *= idx[d] * values.strides[d]
element = values.flat[offset]

(также вам нужно разделить шаги на 8, если ваш dtype - float64) но это не очень красиво ... Вы знаете, как работают strides и flat?

...