numba выдает ошибку при изменении формы массива numpy - PullRequest
0 голосов
/ 05 апреля 2020

Я пытаюсь оптимизировать некоторый код, который имеет несколько циклов и матричных операций. Однако я сталкиваюсь с некоторыми ошибками. Пожалуйста, найдите код и вывод ниже.

Код:

@njit
def list_of_distance(d1): #d1 was declared as List()
    list_of_dis = List()
    for k in range(len(d1)):
        sum_dist = List()
        for j in range(3):
            s = np.sum(square(np.reshape(d1[k][:,:,j].copy(),d1[k][:,:,j].shape[0]*d1[k][:,:,j].shape[1]))) 
            sum_dist.append(s) # square each value in the resulting list (dimenstion)   
        distance = np.sum(sum_dist) # adding the total value for each dimension to a list
        list_of_dis.append(np.round(np.sqrt(distance)))  # Sum the values to get the total squared values of residual images 

    return list_of_dis

Вывод:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function sum at 0x7f898814bd08>) with argument(s) of type(s): (list(int64))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function sum at 0x7f898814bd08>)
[2] During: typing of call at <ipython-input-18-8c787cc8deda> (7)


File "<ipython-input-18-8c787cc8deda>", line 7:
def list_of_distance(d1):
    <source elided>
        for j in range(3):
            s = np.sum(square(np.reshape(d1[k][:,:,j].copy(),d1[k][:,:,j].shape[0]*d1[k][:,:,j].shape[1]))) 
            ^

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

To see Python/NumPy features supported by the latest release of Numba visit:
http://numba.pydata.org/numba-doc/latest/reference/pysupported.html
and
http://numba.pydata.org/numba-doc/latest/reference/numpysupported.html

For more information about typing errors and how to debug them visit:
http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile

If you think your code should work with Numba, please report the error message
and traceback, along with a minimal reproducer at:
https://github.com/numba/numba/issues/new

Может ли кто-нибудь помочь мне с этим вопросом.

Спасибо и наилучшими пожеланиями

Майкл

1 Ответ

1 голос
/ 06 апреля 2020

Мне пришлось внести несколько изменений, чтобы заставить это работать, и смоделировать «d1», но это работает для меня с Numba. Эта основная проблема, которая вызвала ошибку времени выполнения, заключается в том, что np.sum не работает в списке с Numba, хотя он работал правильно, когда я закомментировал @jit. Обертывание sumdist с помощью np.array () решает эту проблему.

d1 = [np.arange(27).reshape(3,3,3), np.arange(27,54).reshape(3,3,3)]

@njit
def list_of_distance(d1): #d1 was declared as List()
    list_of_dis = [] #List() Changed - would not compile
    for k in range(len(d1)):
        sum_dist = [] #List() #List() Changed - would not compile
        for j in range(3):
            s = np.sum(np.square(np.reshape(d1[k][:,:,j].copy(),d1[k][:,:,j].shape[0]*d1[k][:,:,j].shape[1]))) #Added np. to "square"
            sum_dist.append(s) # square each value in the resulting list (dimenstion)   
        distance = np.sum(np.array(sum_dist)) # adding the total value for each dimension to a list - Wrapped list in np.array
        list_of_dis.append(np.round(np.sqrt(distance)))  # Sum the values to get the total squared values of residual images 

    return list_of_dis

list_of_distance(d1)
Out[11]: [79.0, 212.0]
...