Извлечение определенных строк в массив Numpy с помощью Numba - PullRequest
2 голосов
/ 27 октября 2019

У меня есть следующий массив:

import numpy as np
from numba import njit


test_array = np.random.rand(4, 10)

Я создаю «объединенную» функцию, которая разрезает массив и выполняет некоторые операции после этого:

@njit(fastmath = True)
def test_function(array):

   test_array_sliced = test_array[[0,1,3]]

   return test_array_sliced

Однако Numba выдает следующееошибка:

In definition 11:
    TypeError: unsupported array index type list(int64) in [list(int64)]
    raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.

Обходной путь

Я попытался удалить ненужные мне строки с помощью np.delete, но, поскольку мне нужно указать axis, Numba выдает следующую ошибку:

@njit(fastmath = True)
def test_function(array):

   test_array_sliced = np.delete(test_array, obj = 2, axis = 0)

   return test_array_sliced

In definition 1:
    TypeError: np_delete() got an unexpected keyword argument 'axis'
    raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/templates.py:475
This error is usually caused by passing an argument of a type that is unsupported by the named function.

Есть идеи, как извлечь определенные строки в Numba?

Ответы [ 2 ]

2 голосов
/ 27 октября 2019

Я думаю, что это сработает (кажется, что это так в документации ), если вы индексируете массив вместо списка:

test_array_sliced = array[np.array([0,1,3])]

(я изменил массив, который выВы нарезаете на array, то есть то, что вы передаете функции. Возможно, это было сделано намеренно, но будьте осторожны с глобальными переменными!)

1 голос
/ 27 октября 2019

Numba не поддерживает необычную индексацию. Я не уверен на 100%, как выглядит ваш реальный пример использования, но простой способ сделать это будет выглядеть примерно так:

import numpy as np
import numba as nb

@nb.njit
def test_func(x):
    idx = (0, 1, 3)
    res = np.empty((len(idx), x.shape[1]), dtype=x.dtype)
    for i, ix in enumerate(idx):
        res[i] = x[ix]

    return res

test_array = np.random.rand(4, 10)
print(test_array)
print()
print(test_func(test_array))

Редактировать: @kwinkunks - это правильно, и мой первоначальный ответ сделал неверное общее утверждение о том, что модная индексация не поддерживается. Это в ограниченном числе случаев, включая этот.

...