Фильтровать по индексу и сводить в ноль, как tf.sequence_mask - PullRequest
0 голосов
/ 15 января 2019

Я хотел бы отфильтровать мой массив 2D по индексу, а затем выровнять этот массив только по значениям в фильтре. Это в значительной степени то, что сделал бы tf.sequence_mask, но мне нужно это в numpy или другой легкой библиотеке.

Спасибо!

PD: Это пример:

array_2d = [[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]] # this is a numpy array
array_len = [6,5,3]
expected_output = [0,1,2,3,4,5,8,9,10,11,12,21,22,21]

Ответы [ 2 ]

0 голосов
/ 15 января 2019

Вот решение vectorized, использующее логическую маску для индексации array_2d:

array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]]) 
array_len = [6,5,3]

m = ~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T
array_2d[m]
array([ 0,  1,  2,  3,  4,  5,  8,  9, 10, 11, 12, 21, 22, 21])

Подробности

Маска создается, взяв cumsum поверх ndarray ones той же формы, что и array_2d, и выполнив построчное сравнение, чтобы увидеть, какие элементы больше array_len.

Итак, первый шаг - создать следующее ndarray:

np.ones(array_2d.shape).cumsum(axis=1)

array([[1., 2., 3., 4., 5., 6.],
       [1., 2., 3., 4., 5., 6.],
       [1., 2., 3., 4., 5., 6.]])

И выполнить рядовое сравнение с array_len:

~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T

array([[ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True, False],
       [ True,  True,  True, False, False, False]])

Тогда вам просто нужно отфильтровать массив с помощью:

array_2d[m]
array([ 0,  1,  2,  3,  4,  5,  8,  9, 10, 11, 12, 21, 22, 21])
0 голосов
/ 15 января 2019

Это один из способов с использованием логической маскировки и применения ее к вашему сплющенному array_2d

array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]]) 
array_len = [6,5,3]

# Create a boolean mask
mask = np.zeros((array_2d.shape), dtype=bool)

# Change to True for elements to be kept
for i, j in enumerate(array_len):
        mask[i][0:j] = True

expected_output = array_2d.flatten()[mask.flatten()]

выход

array([ 0,  1,  2,  3,  4,  5,  8,  9, 10, 11, 12, 21, 22, 21])
...