Numba jit с пользовательским аргументом numpy.dtype компилируется только тогда, когда атрибут dtype "flatten" - PullRequest
0 голосов
/ 08 февраля 2019

Может кто-нибудь объяснить мне, почему это работает:

import numpy as np
import numba

mytype_dtype = np.dtype([
        ('vector3', np.float32, 3),
        ('color', np.float32, 3),
        ('radius', np.float32, 1)
        ])


t = np.array([((1.,2.,3.),(0.1,0.2,0.3), 1.)], dtype=mytype_dtype)

@numba.njit
def my_func(a,b,v):
    print(v[0]['radius'])
    col = v[0]['color'].flatten()
    vec = v[0]['vector3'].flatten()
    if a>0. :
        return vec
    else:
        return col

И это не удается:

import numpy as np
import numba

mytype_dtype = np.dtype([
        ('vector3', np.float32, 3),
        ('color', np.float32, 3),
        ('radius', np.float32, 1)
        ])


t = np.array([((1.,2.,3.),(0.1,0.2,0.3), 1.)], dtype=mytype_dtype)

@numba.njit
def my_func(a,b,v):
    print(v[0]['radius'])
    col = v[0]['color']
    vec = v[0]['vector3']
    if a>0. :
        return vec
    else:
        return col

Я получаю исключение от Numba во время компиляции:

LoweringError: Can only insert float at [0] in [3 x float]: got i8*

Единственное отличие:

col = v[0]['color'].flatten()
vec = v[0]['vector3'].flatten()

внутри my_func

И если я проверю тип, оба будут class 'numpy.ndarray'

print(type(t[0]['color']))
print(type(t[0]['color'].flatten()))

Заранее спасибо

...