Как получить подмножество диапазона, маскируемого логическим массивом NumPy? - PullRequest
0 голосов
/ 21 апреля 2020

Мне нужно сделать целочисленную индексацию массива непрерывного целочисленного диапазона [0, n), значения которого всегда должны игнорироваться. Игнорируемые значения не должны появляться в результате. И есть отдельный NumPy логический массив длины n (т.е. маска), который указывает, игнорируется ли элемент исходного диапазона.

В чистом Python, Я бы написал так:

def get_non_masked_indices(range_mask, indices):
    return [i for i in indices if not range_mask[i]]

Для этого ввода

#                0  1  2  3  4  5  6  7  8  9
mask = np.array([0, 1, 0, 1, 0, 0, 1, 1, 1, 0], dtype=np.bool)
idxs = np.array([      2, 3, 4,       7,    9])
#                      +  -  +        -     +

результат вызова get_non_masked_indices(mask, idxs) будет

[2, 4, 9]

Это часто используется шаблон обработки массива (особенно в алгоритмах графа). Есть ли функция NumPy, чтобы облегчить это?

До сих пор я пришел со следующими опциями:

  • Native NumPy индексация
  • Маскировка с индексированная маска
  • индексирование маскированного диапазона

собственный NumPy индексирование:

return indices[np.logical_not(range_mask[indices])]

маскирование с помощью индексированной маски:

return np.ma.MaskedArray(indices, range_mask[indices]).compressed()
return np.ma.masked_where(range_mask[indices], indices).compressed()

Индексирование маскированного диапазона:

return np.ma.MaskedArray(np.arange(len(range_mask)), range_mask)[indices].compressed()
return np.ma.masked_where(range_mask, np.arange(len(range_mask)))[indices].compressed()

Пример из приложения

Предположим, у нас есть график, представленный в виде списка NumPy массивы смежных узлов.

adjacent_nodes = [
  np.array([1, 2]),
  np.array([0]),
  np.array([0]),
]
is_colored = np.array([False, False, True])

graph coloring

Функция, которая меня интересует, должна возвращать только неокрашенных соседей узла:

get_non_masked_indices(is_colored, adjacent_nodes[0])  # -> [1]
get_non_masked_indices(is_colored, adjacent_nodes[1])  # -> [0]
get_non_masked_indices(is_colored, adjacent_nodes[2])  # -> [0]
...