Как создать индекс Numpy с двоеточиями везде, кроме одной переменной координаты? - PullRequest
1 голос
/ 25 февраля 2020

Я хотел бы создать индекс в виде

[:, :, :, 0, :, :, :, :]

, где позиция 0 определяется переменной, скажем axis, чтобы разрезать массив NumPy. Очевидно, есть два особых случая, которые легко лечить:

  • axis = 0 будет эквивалентно [0, ...]
  • axis = -1 будет эквивалентно [..., 0]

Но мне интересно, как это можно сделать для любого axis значения?

1 Ответ

1 голос
/ 25 февраля 2020

Вы можете создать кортеж и использовать slice(None) вместо ::

def custom_index(arr, position, index):
    idx = [slice(None)] * len(arr.shape)
    idx[position] = index
    return arr[tuple(idx)]

Быстрый тест:

mat = np.random.random((5, 3))
assert np.all(mat[2, :] == custom_index(mat, 0, 2))  # mat[(2, slice(None))]
assert np.all(mat[:, 2] == custom_index(mat, 1, 2))  # mat[(slice(None), 2))]

РЕДАКТИРОВАТЬ: как Указанный в комментарии правильный путь: np.take

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...