Индексация многомерного массива NumPa внутри класса jumba - PullRequest
0 голосов
/ 13 января 2019

Я пытаюсь вставить небольшой многомерный массив в больший внутри jumbc класса numba. Для маленького массива задаются конкретные позиции большего массива, определенного списком индексов.

Следующий MWE показывает проблему без numba - все работает как положено

import numpy as np

class NumbaClass(object):

    def __init__(self, n, m):
        self.A = np.zeros((n, m))

    # solution 1 using pure python
    def nonNumbaFunction1(self, idx, values):
        self.A[idx[:, None], idx] = values

    # solution 2 using pure python
    def nonNumbaFunction2(self, idx, values):
        self.A[np.ix_(idx, idx)] = values

if __name__ == "__main__":
    n = 6
    m = 8
    obj = NumbaClass(n, m)
    print(f'A =\n{obj.A}')

    idx = np.array([0, 2, 5])
    values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
    print(f'values =\n{values}')

    obj.nonNumbaFunction1(idx, values)
    print(f'A =\n{obj.A}')

    obj.nonNumbaFunction2(idx, values)
    print(f'A =\n{obj.A}')

Обе функции nonNumbaFunction1 и nonNumbaFunction2 не работают внутри класса numba. Так что мое текущее решение выглядит так, что, на мой взгляд, не очень хорошо

import numpy as np

from numba import jitclass      
from numba import int64, float64
from collections import OrderedDict

specs = OrderedDict()
specs['A'] = float64[:, :]

@jitclass(specs)
class NumbaClass(object):

    def __init__(self, n, m):
        self.A = np.zeros((n, m))

    # solution for numba jitclass
    def numbaFunction(self, idx, values):
        for i in range(len(values)):
            idxi = idx[i]
            for j in range(len(values)):
                idxj = idx[j]
                self.A[idxi, idxj] = values[i, j]

if __name__ == "__main__":
    n = 6
    m = 8
    obj = NumbaClass(n, m)
    print(f'A =\n{obj.A}')

    idx = np.array([0, 2, 5])
    values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
    print(f'values =\n{values}')

    obj.numbaFunction(idx, values)
    print(f'A =\n{obj.A}')

Итак, мои вопросы:

  • Кто-нибудь знает решение этой индексации в numba или есть другое векторизованное решение?
  • Есть ли более быстрое решение для nonNumbaFunction1?

Может быть полезно знать, что вставленный массив мал (от 4x4 до 10x10), но эта индексация появляется во вложенных циклах, поэтому она также должна быть тихой и быстрой! Позже мне понадобится аналогичная индексация и для трехмерных объектов.

1 Ответ

0 голосов
/ 15 января 2019

Из-за ограничений поддержки индексирования в numba, я не думаю, что вы можете сделать что-то лучше, чем сами писать циклы for. Чтобы сделать его универсальным для измерений, вы можете использовать декоратор generated_jit для специализации. Примерно так:

def set_2d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            target[idx[i], idx[j]] = values[i, j]

def set_3d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            for k in range(values.shape[2]):
                target[idx[i], idx[j], idx[k]] = values[i, j, l]

@numba.generated_jit
def set_nd(target, values, idx):
    if target.ndim == 2:
        return set_2d
    elif target.ndim == 3:
        return set_3d

Тогда это может быть использовано в вашем jitclass

specs = OrderedDict()
specs['A'] = float64[:, :]

@jitclass(specs)
class NumbaClass(object):
    def __init__(self, n, m):
        self.A = np.zeros((n, m))
    def numbaFunction(self, idx, values):
        set_nd(self.A, values, idx)
...