Почему разделение с использованием «двоеточия и запятой» отличается от использования набора индексов - PullRequest
0 голосов
/ 30 мая 2018

Почему нарезка с использованием «двоеточия и запятой» отличается от использования коллекции индексов?

Вот пример того, что я ожидал получить тот же результат, но это не так:

import numpy as np

a = np.array([[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]])

print(a[[0,1],[0,1]])
# Output
# [[ 1  2  3]
#  [10 11 12]]

print(a[:,[0,1]])
# Output
# [[[ 1  2  3]
#   [ 4  5  6]]
#  [[ 7  8  9]
#   [10 11 12]]]

Почему они не эквивалентны?

1 Ответ

0 голосов
/ 30 мая 2018

В первом случае вы индексируете массив a с 2 списками одинаковой длины, что эквивалентно индексации с 2 массивами одинаковой формы (см. numpy docs для массивов как индексы ).

Следовательно, на выходе получаются a[0,0] (что совпадает с a[0,0,:]) и a[1,1], поэлементные комбинации массива индекса.Ожидается, что это вернет массив формы 2,3.2 потому что это длина индексного массива, а 3 потому что это ось, которая не индексируется.

Однако во втором случае результат равен a[:,0] (эквивалентно a[:,0,:]) и a[:,1].Таким образом, здесь ожидаемый результат представляет собой массив с первым и третьим измерениями, эквивалентными исходному массиву, а второе измерение, равное 2, является длиной индексного массива (который здесь равен исходному размеру второй оси)..

Чтобы ясно показать, что эти две операции явно не совпадают, мы можем попытаться предположить эквивалентность между : и диапазоном, равным длине от оси до третьей оси, что приведет к:

print(a[[0,1],[0,1],[0,1,2]])
IndexError                                Traceback (most recent call last)
<ipython-input-8-110de8f5f6d8> in <module>()
----> 1 print(a[[0,1],[0,1],[0,1,2]])

IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (2,) (2,) (3,) 

Это связано с тем, что элементная комбинация индексных массивов невозможна.В противоположность этому, a[:,:,:] вернет весь массив, а a[[0,1],[0,1],[0,2]] вернет [ 1 12], который, как и ожидалось, представляет собой массив одного измерения длиной 2, как массив индексов.

...