Изменение знака из n элементов, ближайших к значению в массиве Numpy - PullRequest
0 голосов
/ 03 апреля 2019

Я хочу изменить знак n элементов в массиве numpy, который лежит ближе всего к определенному значению, но не меньше. То есть элементы должны быть равны или превышать значение. Существуют ли быстрые методы Numpy, которые могут сделать это эффективно с большими массивами?

Код, который у меня сейчас есть, принимает n значений, которые больше или равны, но не самые близкие, что «хорошо», но не идеально для моих результатов.

def update(arr, n, value):
    updated = 0
    i = 0
    while updated < n:
        if arr[i] >= value: # just a random value above "value"
            arr[i] = -arr[i]
            updated +=1
        i += 1

arr = np.array([9, 8, 2, -4, 3, 4])
n = 3
value = 2
update(arr, n, value)

дает мне

arr = np.array([-9, -8, -2, -4, 3, 4])

когда я вместо этого хочу

arr = np.array([9, 8, -2, -4, -3, -4])

Ответы [ 3 ]

0 голосов
/ 03 апреля 2019

Вы можете использовать argpartition:

arr = np.random.random(20)
value = 0.5
n = 4

nl = np.count_nonzero(arr<value)
closest = np.argpartition(arr, (nl, nl+n-1))[nl:nl+n]
arr[closest] = -arr[closest]
arr
# array([ 0.33697627,  0.42607914, -0.63703314, -0.57517234,  0.82674228,
#        -0.52929285,  0.64776714,  0.25609886,  0.24681445,  0.2486823 ,
#         0.76740245,  0.02368603,  0.21498096, -0.51033841,  0.19901665,
#         0.30939207,  0.69036139,  0.83178506,  0.97243443,  0.47620492])
0 голосов
/ 03 апреля 2019

Это должно работать:

def flip_some(a, n, value):
    more_than = (a >= value)
    first_n_elements = (a < np.sort(a[more_than])[n])
    return np.where(more_than & first_n_elements, -a, a)

print(flip_some(np.array([9, 8, 2, -4, 3, 4]), 3, 2))
print(flip_some(np.arange(10), 2, 5))

Вывод:

[ 9 -8  -2 -4 -3 -4]
[ 0  1  2  3  4 -5 -6 -7  8  9]
0 голосов
/ 03 апреля 2019

Я не обновляю массив на месте, но я бы сделал что-то вроде:

def update(arr, n, value):
    arr_copy = arr.copy()
    diffs = arr - value
    absolute_diffs = np.abs(diffs)
    update_indeces = np.argpartition(absolute_diffs, n)[:n]
    arr_copy[update_indeces] *= -1
    return arr_copy
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...