Удалить строку в numpy.array в numba - PullRequest
0 голосов
/ 04 декабря 2018

Я впервые публикую здесь что-нибудь.Я пытаюсь удалить строку внутри массива NumPy внутри JumbClass.Я написал следующий код, чтобы удалить любую строку, содержащую 3:

>>> a = np.array([[1,2,3,4],[5,6,7,8]])

>>> a

>>> array([[1, 2, 3, 4],
       [5, 6, 7, 8]])

>>> i = np.where(a==3)

>>> i

>>> (array([0]), array([2]))

Я не могу использовать функцию numpy.delete (), поскольку она не поддерживается numba и не может назначить значение типа None для строки.Все, что я мог сделать, это назначить 0 для строки следующим образом:

>>> a[i[0]] = 0

>>> a

>>> array([[0, 0, 0, 0],
       [5, 6, 7, 8]])

Но я хочу полностью удалить строку.

Любая помощь будет оценена.

Спасибовам очень нравится.

Ответы [ 3 ]

0 голосов
/ 04 декабря 2018

Добро пожаловать в Stacoverflow.Вы можете просто использовать нарезку массива, чтобы выбрать только те строки, в которых нет 3.Приведенный ниже код немного сложен, чтобы в основном охватить дополнительные детали для вас, хотя у вас может быть гораздо более короткая версия с опущенными ненужными строками.Назначение клавиш: rows_final = [x for x in range(a.shape[0]) if x not in rows3]

Код:

import numpy as np

a = np.array([[1,2,3,4],[5,6,7,8],[10,11,3,13]])

ind = np.argwhere(a==3)
rows3 = ind[0]
cols3 = ind[1]

print ("Initial Array: \n", a)
print()
print("rows, cols of a==3 : ", rows3, cols3)

rows_final = [x for x in range(a.shape[0]) if x not in rows3]
a_final = a[rows_final,:]

print()
print ("Final Rows: \n", rows_final)
print ("Final Array: \n", a_final)

Выход:

Initial Array: 
 [[ 1  2  3  4]
 [ 5  6  7  8]
 [10 11  3 13]]

rows, cols of a==3 :  [0 2] [2 2]

Final Rows: 
 [1]
Final Array: 
 [[5 6 7 8]]
0 голосов
/ 06 декабря 2018

Это на самом деле непростая задача, поскольку numba имеет следующие ограничения:

  • нет поддержки np.delete
  • нет поддержки ключевого слова axis в np.all и np.any
  • нет поддержки индексации двумерных массивов (по крайней мере, без масок bool)
  • нет или затруднено прямое создание масок bool с np.zeros(shape, dtype=np.bool) или подобными функциями

Но все же есть несколько подходов, которые вы можете использовать для решения вашей проблемы.Я протестировал несколько из них, и создание логической маски представляется наиболее быстрым и чистым способом.

@nb.njit
def delete_workaround(arr, num):
    mask = np.zeros(arr.shape[0], dtype=np.int64) == 0
    mask[np.where(arr == num)[0]] = False
    return arr[mask]

a = np.array([[1,2,3,4],[5,6,7,8]])

delete_workaround(a, 3)

Это решение также обладает огромным преимуществом в сохранении размеров массива , даже когда только одинстрока или пустой массив возвращается.Это важно для jitclasses , так как jitclasses сильно зависят от фиксированных размеров.

Поскольку вы запросите это, я покажу вам решение, которое преобразует массивы в списки и обратно.Поскольку отраженные списки еще не поддерживаются всеми методами python в numba, вам придется использовать оболочку для некоторых частей функции:

@nb.njit
def delete_lrow(arr_list, num):
    idx_list = []
    for i in range(len(arr_list)):
        if (arr_list[i] != num).all():
            idx_list.append(i)
    res_list = [arr_list[i] for i in idx_list]
    return res_list

def wrap_list_del(arr, num):
    arr_list = list(arr)
    return np.array(delete_lrow(arr_list, num))

arr = np.array([[1,2,3,4],[5,6,7,8],[10,11,5,13],[10,11,3,13],[10,11,99,13]])
arr2 = np.random.randint(0, 256, 100000*4).reshape(-1, 4)

%timeit delete_workaround(arr, 3)
# 1.36 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit wrap_list_del(arr, 3)    
# 69.3 µs ± 4.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit delete_workaround(arr2, 3)
# 1.9 ms ± 68.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit wrap_list_del(arr2, 3)
# 1.05 s ± 103 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Так что придерживаться массивов, если у вас уже есть массивы (и даже еслиу вас еще нет массивов, но ваши данные имеют согласованный тип) примерно в 50 раз быстрее для небольших массивов и примерно в 550 раз быстрее для больших массивов .Это то, что нужно помнить: для работы с числовыми данными существуют массивы Numpy!Numpy сильно оптимизирован для работы с числовыми данными!Абсолютно бесполезно преобразовывать массивы числовых данных в другой «формат», если тип данных (dtype) является постоянным, и никакой сверхспециальный материал не требует этого (я едва сталкивался с такой ситуацией).
И это особенно верно для кода, оптимизированного для Numba!Numba в значительной степени полагается на число и постоянство dtypes / формы и т. Д. Еще больше, если вы хотите работать с классными играми.

0 голосов
/ 04 декабря 2018

Я думаю, вам нужно снова назначить удаление переменной a, это сработало для меня.Попробуйте следующий код:

import numpy as np
a = np.array([[1,2,3,4],[5,6,7,8]])
print(a)
i = np.where(a==3)
a=np.delete(a, i, 0) # assign it back to the variable
print(a) 
...