Numpy «Где» функция не может избежать оценки Sqrt (отрицательный) - PullRequest
0 голосов
/ 03 октября 2018

Кажется, что функция np.where сначала оценивает все возможные результаты, а затем оценивает состояние позже.Это означает, что в моем случае он будет вычислять квадратный корень из -5, -4, -3, -2, -1, даже если он не будет использоваться позже.

Мой код работает и работает.Но моя проблема - предупреждение.Я избегал использования цикла для оценки каждого элемента, потому что он будет работать намного медленнее, чем np.where.

Итак, здесь я спрашиваю

  1. Есть ли способ np.where сначала оценить состояние?
  2. Можно ли отключить только это конкретное предупреждение?Как?
  3. Еще один лучший способ сделать это, если у вас есть лучшее предложение.

Вот лишь краткий пример кода, соответствующего моему реальному коду, который является гигантским.Но по сути имеет ту же проблему.

Ввод:

import numpy as np

c=np.arange(10)-5
d=np.where(c>=0, np.sqrt(c) ,c )

Ввод:

RuntimeWarning: invalid value encountered in sqrt
d=np.where(c>=0,np.sqrt(c),c)

Ответы [ 4 ]

0 голосов
/ 03 октября 2018

np.sqrt является ufunc и принимает параметр where.В этом случае ее можно использовать в качестве маски:

In [61]: c = np.arange(10)-5.0
In [62]: d = c.copy()
In [63]: np.sqrt(c, where=c>=0, out=d);
In [64]: d
Out[64]: 
array([-5.        , -4.        , -3.        , -2.        , -1.        ,
        0.        ,  1.        ,  1.41421356,  1.73205081,  2.        ])

В отличие от случая np.where, это не оценивает функцию в элементах ~ where.

0 голосов
/ 03 октября 2018

Одним из решений является не использовать np.where, а вместо этого использовать индексацию.

c = np.arange(10)-5
d = c.copy()
c_positive = c > 0
d[c_positive] = np.sqrt(c[c_positive])
0 голосов
/ 03 октября 2018

Существует намного лучший способ сделать это.Давайте посмотрим, что делает ваш код, чтобы понять, почему.

np.where принимает три массива в качестве входных данных.Массивы не поддерживают отложенную оценку.

d = np.where(c >= 0, np.sqrt(c), c)

Следовательно, эта строка эквивалентна выполнению

a = (c >= 0)
b = np.sqrt(c)
d = np.where(a, b, c)

Обратите внимание, что входные данные вычисляются немедленно, прежде чем where когда-либо будет вызван.

К счастью, вам вообще не нужно использовать where.Вместо этого просто используйте логическую маску:

mask = (c >= 0)
d = np.empty_like(c)
d[mask] = np.sqrt(c[mask])
d[~mask] = c[~mask]

Если вы ожидаете много негативов, вы можете скопировать все элементы вместо только негативных:

d = c.copy()
d[mask] = np.sqrt(c[mask])

Еще лучшеРешением может быть использование замаскированных массивов:

d = np.ma(c, c < 0)
d = np.ma.sqrt(d)

Чтобы получить доступ ко всему массиву данных с неизмененной замаскированной частью, используйте d.data.

0 голосов
/ 03 октября 2018

Это ответ на ваш второй вопрос.

Да, вы можете отключить предупреждения.Используйте предупреждения модуль.

import warnings
warnings.filterwarnings("ignore")
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...