Numpy Where () с All () на двумерной матрице - PullRequest
0 голосов
/ 05 мая 2018
A= np.random.randint(5, size=(25, 4, 4))
U= np.unique(A, axis =0 )
results = np.where((A==U[0]).all(axis=-1))

Используя это, где функция сопоставляет отдельные строки, я хотел бы сопоставить весь массив 4x4, а не только отдельные строки.

вот пример результатов: (массив ([1, 97, 97, 97, 97], dtype = int64), массив ([0, 0, 1, 2, 3], dtype = int64))

Если бы все четыре строки были сопоставлены, результаты содержали бы тот же индекс 4 раза, что и для индекса 97 выше, одна строка была сопоставлена ​​с индексом "1".

Я предполагаю, что если бы весь массив был сопоставлен, то был бы возвращен только один индекс. Пример желаемого вывода, если для одного массива предоставлено несколько индексов: (массив ([97, 97, 97, 97], dtype = int64), массив ([0, 1, 2, 3], dtype = int64)

1 Ответ

0 голосов
/ 05 мая 2018
np.where((A.reshape(A.shape[0],-1) == U[0].reshape(-1)).all(axis=1))

Давайте рассмотрим пример

>>> A = np.random.randint(5, size=(25, 4, 4))
>>> A[:3,...]
array([[[0, 2, 0, 1],
        [1, 0, 3, 0],
        [4, 1, 1, 2],
        [0, 1, 0, 0]],

       [[1, 3, 2, 3],
        [2, 4, 2, 1],
        [3, 3, 2, 3],
        [4, 2, 1, 1]],

       [[4, 0, 3, 3],
        [1, 0, 4, 4],
        [0, 0, 2, 3],
        [4, 1, 2, 2]]])
>>> U = np.unique(A, axis=0)
>>> U[0]
array([[0, 2, 0, 1],
       [1, 0, 3, 0],
       [4, 1, 1, 2],
       [0, 1, 0, 0]])

Теперь вам нужно найти U[0] в A, если я правильно понял. Проще сопоставлять строку за строкой, поэтому давайте изменим массивы 4x4 на строки

>>> A.reshape(A.shape[0], -1)[:3,...]
array([[0, 2, 0, 1, 1, 0, 3, 0, 4, 1, 1, 2, 0, 1, 0, 0],
       [1, 3, 2, 3, 2, 4, 2, 1, 3, 3, 2, 3, 4, 2, 1, 1],
       [4, 0, 3, 3, 1, 0, 4, 4, 0, 0, 2, 3, 4, 1, 2, 2]])
>>> U[0].reshape(-1)
array([0, 2, 0, 1, 1, 0, 3, 0, 4, 1, 1, 2, 0, 1, 0, 0])

Теперь мы можем сравнить их с np.where, но если мы не будем осторожны, мы получим поэлементное сравнение, поэтому нам нужно использовать np.all(axis=1), чтобы убедиться, что они сравниваются строка за строкой:

>>> np.where(np.all(A.reshape(25, -1) == U[0].reshape(-1), axis=1))
(array([0]),)

РЕДАКТИРОВАТЬ мне пришло в голову, что вы можете использовать несколько осей с np.all и вообще избежать изменения формы:

np.where((A == U[0]).all(axis=(1,2)))
...