Индексирование numpy массив со списком слайсов - PullRequest
1 голос
/ 30 января 2020

У меня есть список slices, и я использую их для индексации массива numpy.

arr = np.arange(25).reshape(5, 5)
# array([[ 0,  1,  2,  3,  4],
#        [ 5,  6,  7,  8,  9],
#        [10, 11, 12, 13, 14],
#        [15, 16, 17, 18, 19],
#        [20, 21, 22, 23, 24]])

slice_list = list(map(lambda i: slice(i, i+2), [1, 2]))
# [slice(1, 3, None), slice(2, 4, None)]

print(arr[slice_list])
# == arr[1:3, 2:4]
# [[ 7  8]
#  [12 13]]

Это нормально работает, но ломается, если у меня меньше кусков, чем количество измерений массива Я хочу индексировать.

arr3d = arr[np.newaxis, :, :]  # dims: [1, 5, 5]
arr3d[:, slice_list]
# IndexError: only integers, slices (`:`), ellipsis (`...`),(`None`) 
#  numpy.newaxis and integer or boolean arrays are valid indices

Однако работают следующие примеры:

arr3d[:, slice_list[0], slice_list[1]]
arr3d[[slice(None)] + slice_list]
arr3d[:, [[1], [2]], [2, 3]]

Можно ли использовать список фрагментов для индексации массива с большим количеством измерений. Я хочу сделать что-то вроде:

arr[..., slice_list]
arr[..., slice_list, :]
arr[:, slice_list, :]

, не думая о размерах массива и не выясняя, сколько [slice(None)]*X мне нужно набить с обеих сторон моего slice_list.

1 Ответ

1 голос
/ 30 января 2020

Вы можете сделать это, используя наборы срезов и объекты с многоточием. Просто поместите все элементы, которые вы хотите использовать для индексации, в кортеж и используйте его в качестве индекса:

import numpy as np

arr = np.arange(24).reshape(2, 3, 4)
print(arr)
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
# 
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]
slice_tup = tuple(map(lambda i: slice(i, i+2), [1, 2]))
print(slice_tup)
# (slice(1, 3, None), slice(2, 4, None))
print(arr[slice_tup])
# [[[20 21 22 23]]]

# arr[..., slice_list]
print(arr[(Ellipsis, *slice_tup)])
# [[[ 6  7]
#   [10 11]]
# 
#  [[18 19]
#   [22 23]]]

# arr[..., slice_list, :]
print(arr[(Ellipsis, *slice_tup, slice(None))])
# [[[20 21 22 23]]]

# arr[:, slice_list, :]
print(arr[(slice(None), *slice_tup, slice(None))])
# IndexError: too many indices for array
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...