Векторизация функции - неверный индекс для скалярной переменной - PullRequest
1 голос
/ 04 апреля 2020

Я новичок в python. Я нашел интересную статью о векторизации, поэтому я начал изучать ее. Хотя я могу сделать это:

 def cost(a, b):
    "Return a-b if a>b, otherwise return a+b"
    if a > b:
        return a - b
    else:
        return a + b

cost_vector= np.vectorize(cost)
print(z([1,2,3],[3,4,5]))

вывод: [4 6 8]

Я не могу сделать это:

ww = [[1,2,3,4,5,6],[2,2,3,4,5,6],[3,2,3,4,5,6],[4,2,3,4,5,6],[5,2,3,4,5,6],[6,2,3,4,5,6]]

def cost(ww, a, b):
    if a > b:
        return ww[a][b]
    else:
        return ww[b][a]

z = np.vectorize(cost)
print(z(ww, [1,2,3], [3,4,5]))

output: IndexError: invalid index to scalar variable.

Я не могу понять, как сделать это сопоставление с моим массивом

Спасибо

1 Ответ

1 голос
/ 04 апреля 2020

Проблема с вашим кодом в том, что np.vectorize() пытается разложить все аргументы, включая ww. Согласно документации вам необходимо исключить ее с помощью параметра exclude, например:

import numpy as np


ww = [[1, 2, 3, 4, 5, 6], [2, 2, 3, 4, 5, 6], [3, 2, 3, 4, 5, 6],
      [4, 2, 3, 4, 5, 6], [5, 2, 3, 4, 5, 6], [6, 2, 3, 4, 5, 6]]


def cost(ww, a, b):
    if a > b:
        return ww[a][b]
    else:
        return ww[b][a]


v_cost = np.vectorize(cost, excluded={0})
print(v_cost(ww, [1, 2, 3], [3, 4, 5]))
# [2 3 4]

Обратите внимание, что вы можете сделать это в NumPy без необходимости np.vectorize() -украшенная функция. Вам просто нужно убедиться, что ww является массивом NumPy и использовать np.where() дважды:

import numpy as np


ww = [[1, 2, 3, 4, 5, 6], [2, 2, 3, 4, 5, 6], [3, 2, 3, 4, 5, 6],
      [4, 2, 3, 4, 5, 6], [5, 2, 3, 4, 5, 6], [6, 2, 3, 4, 5, 6]]


def cost(ww, a, b):
    return np.array(ww)[np.where(a > b, a, b), np.where(a > b, b, a)]


print(cost(ww, [1, 2, 3], [3, 4, 5]))
# [2 3 4]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...