Python / Numba - пользовательский объект класса в качестве типа ввода - PullRequest
0 голосов
/ 22 мая 2018

Я начинаю с numba, и моя первая цель - попытаться ускорить не столь сложную функцию с помощью вложенного цикла.

Имеется следующий класс:

class TestA:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def get_mult(self):
        return self.a * self.b

и numpy ndarray, который содержит объекты класса TestA.Размер (N,), где N обычно ~ 3 миллиона в длину.

Теперь, учитывая следующую функцию:

def test_no_jit(custom_class_obj_container):
    container_length = len(custom_class_obj_container)
    sum = 0
    for i in range(container_length):
        for j in range(i + 1, container_length):
            obj_i = custom_class_obj_container[i]
            obj_j = custom_class_obj_container[j]
            sum += (obj_i.get_mult() + obj_j.get_mult())

    return sum

Я попытался поиграть numba, чтобы заставить его работать с вышеуказанной функцией, однако я не могу заставить ее работать сnopython=True флаг, и если он установлен в false, то время выполнения выше, чем у функции no-jit.

Вот моя последняя попытка jit функции (также использующая nb.prange):

@nb.jit(nopython=False, parallel=True)
def test_jit(custom_class_obj_container):
    container_length = len(custom_class_obj_container)
    sum = 0
    for i in nb.prange(container_length):
        for j in nb.prange(i + 1, container_length):
            obj_i = custom_class_obj_container[i]
            obj_j = custom_class_obj_container[j]
            sum += (obj_i.get_mult() + obj_j.get_mult())

    return sum

Я пытался искать, но не могу найтиучебник о том, как определить пользовательский класс в сигнатуре, и как мне поступить, чтобы ускорить функцию такого рода и заставить ее работать на GPU и, возможно, (любая информация по этому вопросу будет высоко цениться), чтобы получить ееработать с cuda библиотеками - которые установлены и готовы к использованию (ранее использовались с tensorflow)

1 Ответ

0 голосов
/ 08 ноября 2018

В документах numba приведен пример создания пользовательского типа, даже для режима nopython: https://numba.pydata.org/numba-doc/latest/extending/interval-example.html

В вашем случае, если, конечно, это действительно упрощенная версия того, что вы действительно хотите сделать,кажется, что самый простой подход - это повторно использовать существующие типы.Кроме того, создание массива объектов длиной 3M будет медленным и приведет к фрагментации памяти (поскольку объекты не хранятся в смежных блоках).

Пример использования массивов записей длярешить проблему:

x_dt = np.dtype([('a', np.float64),
                 ('b', np.float64)])
n = 30000
buf = np.arange(n*2).reshape((n, 2)).astype(np.float64)
vec3 = np.recarray(n, dtype=x_dt, buf=buf) 

@numba.njit
def mult(a):
    return a.a * a.b

@numba.jit(nopython=True, parallel=True)
def sum_of_prod(vector):
    sum = 0
    vector_len = len(vector)
    for i in numba.prange(vector_len):
        for j in numba.prange(i + 1, vector_len):
            sum += mult(vector[i]) + mult(vector[j])
    return sum

sum_of_prod(vec3)

FWIW, я не эксперт по numba.Я нашел этот вопрос, когда искал, как реализовать пользовательский тип в Numba для нечисловых вещей.В вашем случае, поскольку это очень числовой тип, я думаю, что пользовательский тип, вероятно, излишний.

...