Есть ли способ выполнить несколько проверок столбцов, используя индексирование массива numpy? - PullRequest
1 голос
/ 13 апреля 2019

У меня есть двумерный массив данных, и я пытаюсь эффективно обрезать плохие столбцы из этих данных. Я пытаюсь удалить все столбцы, которые содержат значение 0, у которых абсолютная разница больше 12 между минимальным и максимальным значениями или которые содержат значение больше 9,5.

Код, который у меня есть, работает, но он довольно медленный. Насколько я понимаю, на заднем плане есть цикл над моим массивом для каждой из этих строк кода. Мне интересно, есть ли способ уменьшить это до одного цикла.

import numpy as np

data_array = data_array[:,abs(data_array).min(0)!=0]
data_array = data_array[:,abs(data_array.min(0)-data_array.max(0)) < 12]
data_array = data_array[:,abs(data_array).max(0) < 9.5]

1 Ответ

0 голосов
/ 13 апреля 2019

Я думаю, что невозможно выполнить эти три проверки в одном цикле.

Вероятно, вы улучшите производительность, правильно упорядочив операции обрезки. Действительно, вы должны сначала проверить условие, которое удаляет большинство столбцов, чтобы массив, передаваемый второму фильтру, был как можно меньше. Тот же критерий применяется к остальным фильтрам.

Согласно комментариям ваши данные варьируются от -30 до 30. Можно ожидать, что чаще всего недопустимые столбцы - это столбцы, значения которых превышают 9.5. Я также предполагаю, что наименее частой причиной, по которой столбцы отбрасываются, является наличие нулевого значения. Если эти предположения не верны, вы должны соответствующим образом изменить порядок фильтров. Дальнейшее улучшение может быть достигнуто путем удаления ненужных вызовов функций (таких как abs).

Следующие функции реализуют те же операции фильтрации в другом порядке, как описано выше:

import numpy as np

def trim(x, low=0, high=9.5, diff=12):
    x = x[:, np.all(x != 0, axis=0)]
    x = x[:, np.ptp(x, axis=0) <= diff]
    x = x[:, np.all(x <= high, axis=0)]
    return x

def trim_reordered(x, low=0, high=9.5, diff=12):
    x = x[:, np.all(x <= high, axis=0)]
    x = x[:, np.ptp(x, axis=0) <= diff]
    x = x[:, np.all(x != 0, axis=0)]
    return x

Демо

In [205]: np.random.seed(213)

In [206]: small_arr = np.random.randint(low=-30, high=30, size=(3, 10))

In [207]: small_arr
Out[207]: 
array([[ 13,   6,   2, -29,  13,  11, -12, -24,   5,   9],
       [ 29,  24,  16, -21, -27,  -5,  -5, -16,  21, -29],
       [-10,  10, -24, -10,   4,   0,  -8, -23,   0,   4]])

In [208]: trim(small_arr)
Out[208]: 
array([[-12, -24],
       [ -5, -16],
       [ -8, -23]])

In [209]: large_arr = np.random.randint(low=-30, high=30, size=(10, 10**6))

In [210]: %timeit trim(large_arr)
77.3 ms ± 470 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [211]: %timeit trim_reordered(large_arr)
16.1 ms ± 174 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [212]: np.all(trim(large_arr) == trim_reordered(large_arr))
Out[212]: True
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...