Numpy: Проверка элементов массива выше некоторого значения, х раз подряд? - PullRequest
0 голосов
/ 09 ноября 2018

Я пытаюсь проверить, не превышают ли значения в массиве некоторое значение подряд несколько раз.

Например

arr1 = np.array([1,2,1,3,4,5,6,7])
arr2 = np.array([1,2,1,3,4,2,6,7])

Скажем, я хочу проверить, имеет ли элемент в массиве >=3 в течение четырех последовательных периодов. Тест выдаст true для arr1, но false для arr2.

Ответы [ 3 ]

0 голосов
/ 09 ноября 2018

Вот низкотехнологичный, но быстрый метод. Создайте логический массив, сформируйте cumsum () и сравните каждый элемент с одним n местами. Если разница равна n, это должна быть полоса True s.

def check_streak(a, th, n):
    ps = (a>=th).cumsum()
    return (ps[n:]-ps[:ps.size-n] == n).any()
0 голосов
/ 09 ноября 2018

Другое решение (но медленнее других)

import numpy as np
from numpy.lib.stride_tricks import as_strided

def f(arr, threshold=3, n=4):
    arr = as_strided(arr, shape=(arr.shape[0]-n+1, n), strides=2*arr.strides)
    return (arr >= threshold).all(axis=1).any()


# How it works:
# arr = np.array([1, 2, 3, 4, 5, 6, 7, 8])
# n = 4
# threshold = 3

# arr = as_strided(arr, shape=(arr.shape[0]-n+1, n), strides=2*arr.strides)
# print(arr)
# [[1 2 3 4]
#  [2 3 4 5]
#  [3 4 5 6]
#  [4 5 6 7]
#  [5 6 7 8]]

# print(arr >= threshold)
# [[False False  True  True]
#  [False  True  True  True]
#  [ True  True  True  True]
#  [ True  True  True  True]
#  [ True  True  True  True]]

# print((arr >= threshold).all(axis=1))
# [False False  True  True  True]

# print((arr >= threshold).all(axis=1).any())
# True
0 голосов
/ 09 ноября 2018

Вот один способ с convolution -

def cross_thresh_convolve(arr, thresh, N):
    # Detect if arr crosses thresh for N consecutive times anywhere
    return (np.convolve(arr>=thresh,np.ones(N,dtype=int))==N).any()

В качестве альтернативы с binary-dilation -

from scipy.ndimage.morphology import binary_erosion

def cross_thresh_erosion(arr, thresh, N):
    return binary_erosion(arr>=thresh, np.ones(N)).any()

Пробные прогоны -

In [43]: arr1 = np.array([1,2,1,3,4,5,6,7])
    ...: arr2 = np.array([1,2,1,3,4,2,6,7])

In [44]: print cross_thresh_convolve(arr1, thresh=3, N=4)
    ...: print cross_thresh_erosion(arr1, thresh=3, N=4)
    ...: print cross_thresh_convolve(arr2, thresh=3, N=4)
    ...: print cross_thresh_erosion(arr2, thresh=3, N=4)
True
True
False
False

Общие сравнения

Чтобы охватить общие сравнения, скажем, хотим ли мы искать greater или less-than или даже просто сравнить на равенство сзначение, мы могли бы использовать встроенные функции сравнения NumPy, чтобы заменить часть arr>=thresh из более ранних решений и, следовательно, дать себе общие реализации, например, так: *

def consecutive_comp_convolve(arr, comp, N, comparison=np.greater_equal):
    return (np.convolve(comparison(arr,comp),np.ones(N,dtype=int))==N).any()

def consecutive_comp_erosion(arr, comp, N, comparison=np.greater_equal):
    return binary_erosion(comparison(arr,comp), np.ones(N)).any()

Следовательно, наш конкретный пример запуска будет -

consecutive_comp_convolve(arr1, comp=3, N=4, comparison=np.greater_equal)
consecutive_comp_erosion(arr1, comp=3, N=4, comparison=np.greater_equal)
consecutive_comp_convolve(arr2, comp=3, N=4, comparison=np.greater_equal)
consecutive_comp_erosion(arr2, comp=3, N=4, comparison=np.greater_equal)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...