Почему numpy .vectorize () изменяет вывод деления скалярной функции? - PullRequest
1 голос
/ 01 марта 2020

Я получаю странный результат, когда векторизую функцию с numpy.

import numpy as np
def scalar_function(x, y):
    """ A function that returns x*y if x<y and x/y otherwise
    """
    if x < y :
        out = x * y 
    else:
        out = x/y 
    return out

def vector_function(x, y):
    """
    Make it possible to accept vectors as input
    """
    v_scalar_function = np.vectorize(scalar_function)
    return v_scalar_function(x, y)

у нас есть

scalar_function(4,3)
# 1.3333333333333333

Почему векторизованная версия дает такой странный вывод?

vector_function(np.array([3,4]), np.array([4,3]))
[12  1]

Хотя этот вызов векторизованной версии работает нормально:

vector_function(np.array([4,4]), np.array([4,3]))
[1.         1.33333333]

Чтение numpy .divide :

Примечания Оператор деления на пол // был добавлен в Python 2.2, создавая // и / или эквивалентные операторы. Операция деления на этаж по умолчанию / может быть заменена истинным делением с __future__ делением импорта. В Python 3.0, // является оператором разделения этажа и / истинным оператором деления. Функция true_divide (x1, x2) эквивалентна истинному делению в Python.

Мне кажется, это может быть остающаяся проблема, связанная с python2? Но я пользуюсь python 3!

Ответы [ 2 ]

3 голосов
/ 01 марта 2020

Документы для numpy.vectorize состояния:

Тип вывода определяется путем оценки первого элемента ввода, если он не указан

Поскольку вы не указали тип возвращаемых данных, а первый пример - целочисленное умножение, первый массив также имеет целочисленный тип и округляет значения. И наоборот, когда первой операцией является деление, тип данных автоматически преобразуется в плавающее. Вы можете исправить свой код, указав dtype в vector_function (который необязательно должен быть таким же большим, как 64-битный для этой проблемы):

def vector_function(x, y):
    """
    Make it possible to accept vectors as input
    """
    v_scalar_function = np.vectorize(scalar_function, otypes=[np.float64])
    return v_scalar_function(x, y)

Отдельно вы также должны сделать заметку из та же самая документация о том, что numpy.vectorize является вспомогательной функцией и в основном просто оборачивает Python for l oop, поэтому не векторизована в том смысле, что она обеспечивает какой-либо реальный прирост производительности.

Для бинарный выбор, подобный этому, лучший общий подход будет выглядеть так:

def vectorized_scalar_function(arr_1, arr_2):
    return np.where(arr_1 < arr_2, arr_1 * arr_2, arr_1 / arr_2)

print(vectorized_scalar_function(np.array([4,4]), np.array([4,3])))
print(vectorized_scalar_function(np.array([3,4]), np.array([4,3])))

Выше должно быть на несколько порядков быстрее и (возможно, по совпадению, а не по жесткому правилу, на которое можно положиться) не страдает проблема приведения типа для результата.

2 голосов
/ 01 марта 2020

Проверка того, какие состояния вызваны:

import numpy as np

def scalar_function(x, y):
    """ A function that returns x*y if x<y and x/y otherwise
    """
    if x < y :
        print('if x: ',x)
        print('if y: ',y)
        out = x * y 
        print('if out', out)
    else:
        print('else x: ',x)
        print('else y: ',y)
        out = x/y
        print('else out', out)

    return out

def vector_function(x, y):
    """
    Make it possible to accept vectors as input
    """
    v_scalar_function = np.vectorize(scalar_function)
    return v_scalar_function(x, y)


vector_function(np.array([3,4]), np.array([4,3]))

if x:  3
if y:  4
if out 12
if x:  3
if y:  4
if out 12
else x:  4
else y:  3
else out 1.3333333333333333 # <-- seems that the value is calculated correctly, but the wrong dtype is returned

Итак, вы можете переписать скалярную функцию:

def scalar_function(x, y):
    """ A function that returns x*y if x<y and x/y otherwise
    """
    if x < y :
        out = x * y 
    else:
        out = x/y
    return float(out)


vector_function(np.array([3,4]), np.array([4,3]))
array([12.        ,  1.33333333])
...