Проблема с вашим кодом в том, что np.vectorize()
пытается разложить все аргументы, включая ww
. Согласно документации вам необходимо исключить ее с помощью параметра exclude
, например:
import numpy as np
ww = [[1, 2, 3, 4, 5, 6], [2, 2, 3, 4, 5, 6], [3, 2, 3, 4, 5, 6],
[4, 2, 3, 4, 5, 6], [5, 2, 3, 4, 5, 6], [6, 2, 3, 4, 5, 6]]
def cost(ww, a, b):
if a > b:
return ww[a][b]
else:
return ww[b][a]
v_cost = np.vectorize(cost, excluded={0})
print(v_cost(ww, [1, 2, 3], [3, 4, 5]))
# [2 3 4]
Обратите внимание, что вы можете сделать это в NumPy без необходимости np.vectorize()
-украшенная функция. Вам просто нужно убедиться, что ww
является массивом NumPy и использовать np.where()
дважды:
import numpy as np
ww = [[1, 2, 3, 4, 5, 6], [2, 2, 3, 4, 5, 6], [3, 2, 3, 4, 5, 6],
[4, 2, 3, 4, 5, 6], [5, 2, 3, 4, 5, 6], [6, 2, 3, 4, 5, 6]]
def cost(ww, a, b):
return np.array(ww)[np.where(a > b, a, b), np.where(a > b, b, a)]
print(cost(ww, [1, 2, 3], [3, 4, 5]))
# [2 3 4]