Использование numpy.where для предотвращения выхода за пределы - PullRequest
0 голосов
/ 18 сентября 2018

Я пытаюсь найти значения в массиве на основе массива индексов. Этот массив индексов может содержать индексы, которые могут быть вне пределов. В этом случае я хочу вернуть определенное значение (здесь 0).

(я мог бы использовать цикл for, но это было бы слишком медленно.)

Итак, я делаю это:

data = np.arange(1000).reshape(10, 10, 10)
i = np.arange(9).reshape(3, 3)
i[0, 0] = 10
condition = (i[:, 0] < 10) & (i[:, 1] < 10) & (i[:, 2] < 10)
values = np.where(condition, data[i[:, 0], i[:, 1], i[:, 2]], 0)

Однако я все еще получаю сообщение об ошибке:

IndexError: index 10 is out of bounds for axis 0 with size 10

Я полагаю, это потому, что вторые параметры не лениво вычисляются и вычисляются перед вызовом функции.

Есть ли в numpy решение для доступа к массиву на основе условия, но при этом сохранить порядок? Сохраняя порядок, я имею в виду, что не могу сначала отфильтровать массив, потому что могу потерять порядок в конечном результате. В конце, в этом конкретном примере, я все еще хочу, чтобы массив значений содержал 0, когда индексы выходят за пределы. Таким образом, конечный результат будет:

array([ 0, 345, 678])

Ответы [ 2 ]

0 голосов
/ 18 сентября 2018

В каждом столбце индексного массива хранятся индексы для каждого измерения. Мы могли бы сгенерировать маску действительных (по границам) и установить в ней недействительные нули. то есть случаи за пределами границ будут проиндексированы с помощью [0,0,0], затем разрешится индексирование массива этой измененной версией и, наконец, снова используйте маску для сброса недействительных, например, так -

shp = data.shape
valid_mask = (i < shp).all(1)
i[~valid_mask] = 0
out = np.where(valid_mask,data[tuple(i.T)],0)

Модифицированная компактная версия того же самого без изменения i, будет -

out = np.where(valid_mask,data[tuple(np.where(valid_mask,i.T,0))],0)
0 голосов
/ 18 сентября 2018

Сначала индекс, затем примените исправление к правильным значениям.

shp = np.array(data.shape)
j = i % shp 
res = data[j.T.tolist()]
res[(i >= shp).nonzero()[0]] = 0

print(res)
array([  0, 345, 678])
...