Numpy вычислить Min Max в случайном массиве 2D или 1D - PullRequest
2 голосов
/ 05 августа 2020

Я пытаюсь получить сумму минимального или максимального значения в каждой строке. Если число в первом индексе больше 12, вернуть максимальное число в строке, иначе вернуть минимальное число в строке. Обратите внимание, что в приведенном ниже примере это 2D-массив 4 x 3. Однако я хочу, чтобы мой код работал при любом размере или форме массива.

import numpy as np

arr = np.array([[11, 12, 13],
                [14, 15, 16],
                [17, 15, 11],
                [12, 14, 15]])

i_max = np.amax(arr,axis=1)
i_min = np.amin(arr,axis=1)
print(i_max)
print(i_min)

Пока я могу получить только минимальное или максимальное число в каждой строке, используя amax и amin. Мне не хватает доступа к первому значению в каждой строке и использования оператора if else для сравнения размера с 12. Может ли кто-нибудь дать подсказку.

Правильный результат для данного образца должен быть 11 + 16 + 17 + 12 = 56

Ответы [ 2 ]

1 голос
/ 05 августа 2020

Вы почти у цели. Вы можете создать mask (как хотите, вот первый элемент каждой строки больше 12) и рассчитать свою операцию, как показано ниже. Его легко расширить до многомерного массива, изменив ось в min/max и желаемое условие mask:

mask=arr[:,0]>12
(arr.max(1)*mask + arr.min(1)*~mask).sum()

вывод:

56

Сравнение :

def m1(arr):
  mask=arr[:,0]>12
  return (arr.max(1)*mask+arr.min(1)*~mask).sum()

#@Dieter's solution
def m2(arr):
  return np.where(arr[:,0] > 12, np.max(arr, axis=1), np.min(arr, axis=1)).sum()
 
in_ = [np.random.randint(100, size=(n,n)) for n in [10,100,1000,10000]]

Время выполнения : m1 кажется немного быстрее, однако они сходятся к той же производительности в больших массивах . введите описание изображения здесь

1 голос
/ 05 августа 2020

np. Где содержит 3 входа . Первый - это ваше условие [False, True, True False ...] второй и третий входы, это возможные значения .

Таким образом, если условие истинно, grep значение второго входа в противном случае возьмите значение третьего входа.

np.where(arr[:,0] > 12, np.max(arr, axis=1), np.min(arr, axis=1))

return: array([11, 16, 17, 12])

и, если вам нужна сумма, просто добавьте сумму :):

np.where(arr[:,0] > 12, np.max(arr, axis=1), np.min(arr, axis=1)).sum()

return: 56

другой способ представления: [xv if c else yv for c, xv, yv in zip(condition, x, y)]

или просто прочтите документацию: https://numpy.org/doc/stable/reference/generated/numpy.where.html:)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...