Как получить массив индексов минимальных значений для многомерного массива? - PullRequest
1 голос
/ 28 июня 2019

У меня есть многомерный массив с формой (2, 2, 3), подобный этому:

array([[[  0.64,   0.49,   2.56],
    [  7.84,  13.69,  21.16]],

   [[ 33.64,  44.89,  57.76],
    [ 77.44,  94.09, 112.36]]])

Я бы хотел найти индексы мин для каждой строки. Таким образом, для этого примера есть 4 минимума: 0,49, 7,84, 33,64 и 77,44.

Чтобы получить индексы того минимума, я думал, что это сработает:

idx_arr = np.unravel_index(np.argmin(my_array,axis=2),my_array.shape)

Это дает следующий массив индексов:

(array([[0, 0],
    [0, 0]]), array([[0, 0],
    [0, 0]]), array([[1, 0],
    [0, 0]]))

Однако минимумы рассчитаны неправильно, как можно видеть:

my_array[idx_arr]
array([[0.49, 0.64],
   [0.64, 0.64]])

Что мне там не хватает?

1 Ответ

1 голос
/ 28 июня 2019

argmin действительно правильно вычисляет значения.Но вы неправильно понимаете, что ожидает np.unravel_index.

Из документов:

Преобразует плоский индекс или массив плоских индексов в набор координатных массивов.

Чтобы увидеть, какой тип ввода он примет, чтобы получить желаемый результат здесь, нам нужно сосредоточиться на главном: он преобразует плоский массив в правильный массив координат для определенного местоположения внеплоские условия.По сути, то, что ожидалось, это координаты ваших желаемых точек , как если бы ваш входной массив был сглажен.

import numpy as np
inp = np.array([[[  0.64,   0.49,   2.56],
    [  7.84,  13.69,  21.16]],

   [[ 33.64,  44.89,  57.76],
    [ 77.44,  94.09, 112.36]]])

idx = inp.argmin(axis=-1)
#Output:
array([[1, 0],
       [0, 0]], dtype=int64)

Обратите внимание, что вы не можете напрямую отправить этот idx, поскольку он не представляет правильные координаты для плоской версии массива inp.

Это будет выглядеть следующим образом:

flat_idx = np.arange(0, idx.size*inp.shape[-1], inp.shape[-1]) + idx.flatten()
#Output:
array([1, 3, 6, 9], dtype=int64)

И мы видим, что unravel_index с радостью принимает его.

temp = np.unravel_index(flat_idx, inp.shape)
#Output:
(array([0, 0, 1, 1], dtype=int64),
 array([0, 1, 0, 1], dtype=int64),
 array([1, 0, 0, 0], dtype=int64))

inp[temp]

Вывод:

array([ 0.49,  7.84, 33.64, 77.44])

Также, посмотрев на выходной кортеж,мы можем заметить, что это не так уж сложно воссоздать и самих себя.Обратите внимание, что последний массив соответствует сплюснутой форме idx, в то время как первые два массива по существу позволяют индексировать по первым двум осям inp.


И чтобы подготовить это, мы можем фактическииспользуйте функцию unravel_index довольно изящно, как показано ниже:

real_idx = (*np.unravel_index(np.arange(idx.size), idx.shape), idx.flatten())
inp[real_idx]
#Output:
array([ 0.49,  7.84, 33.64, 77.44])
...