numpy найти срез по осям, где встречаются первое и последнее встречающиеся значения - PullRequest
3 голосов
/ 06 мая 2020

У меня есть массив 3D numpy с целочисленными значениями, что-то определенное как:

import numpy as np
x = np.random.randint(0,  100, (10, 10, 10))

Теперь я хочу найти последний фрагмент (или, альтернативно, первый фрагмент) по заданным осям (скажем, 1), где встречается конкретное значение. На данный момент я делаю что-то вроде:

first=None
last=None
val = 20

for i in range(len(x.shape[1]):
    slice = x[:, i, :]
    if len(slice[slice==val]) > 0:
        if not first:
            first = i
        last = i

return first, last

Это кажется немного непонятным c, и мне интересно, есть ли numpy magi c, чтобы сделать это?

Ответы [ 2 ]

4 голосов
/ 06 мая 2020

Вы, вероятно, можете оптимизировать это, чтобы быть быстрее, но вот векторизованная версия того, что вы ищете:

axis = 1
mask = np.where(x==val)[axis]
first, last = np.amin(mask), np.amax(mask)

Сначала он находит элемент val в вашем массиве с помощью np.where и возвращает min и max индексов по заданной оси.

1 голос
/ 06 мая 2020

На свой вопрос вы хотите проверить, есть ли такой допустимый фрагмент, и, следовательно, получить индексы start / first, stop / last. При отсутствии такого допустимого фрагмента мы должны вернуть None. Это требует дополнительной проверки. Кроме того, мы можем использовать masking для эффективного получения этих индексов, например:

def slice_info(x, val):
    n = (x==val).any((0,2))
    if n.any():
        return n.argmax(), len(n)-n[::-1].argmax()-1
    else:
        return None,None

Бенчмаркинг

Другие предлагаемые решения:

# https://stackoverflow.com/a/61629916/ @Ehsan
def where_amin_amax(x, val):
    axis = 1
    mask = np.where(x==val)[axis]
    first, last = np.amin(mask), np.amax(mask)
    return first, last

Сроки -

# Same setup as in given sample
In [157]: np.random.seed(0)
     ...: x = np.random.randint(0,  100, (10, 10, 10))

In [158]: %timeit where_amin_amax(x, val=20)
     ...: %timeit slice_info(x, val=20)
15.1 µs ± 287 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
9.63 µs ± 43.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

# Bigger
In [159]: np.random.seed(0)
     ...: x = np.random.randint(0,  100, (100, 100, 100))

In [160]: %timeit where_amin_amax(x, val=20)
     ...: %timeit slice_info(x, val=20)
3.34 ms ± 31.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
691 µs ± 3.69 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...