Как я могу векторизовать эту (numpy) операцию в python? - PullRequest
0 голосов
/ 17 марта 2019

У меня есть два вектора формы (batch, dim), которые я пытаюсь вычесть друг из друга.В настоящее время я использую простой цикл для вычитания определенной записи в векторе (т.е. error) на основе второго вектора (т.е. label) из 1:

per_ts_loss=0
for i, idx in enumerate(np.argmax(label, axis=1)):
    error[i, idx] -=1
    per_ts_loss += error[i, idx]

Как я могу векторизовать это?

Например, ошибка и метка могут выглядеть следующим образом:

error :
array([[ 0.5488135   0.71518937  0.60276338  0.54488318  0.4236548 ]
       [ 0.64589411  0.43758721  0.891773    0.96366276  0.38344152]])
label:
    array([[0, 0, 0, 1, 0 ],
           [0, 1, 0, 0, 0]])

для этого примера выполнение приведенного ниже кода приводит к следующим результатам:

for i, idx in enumerate(np.argmax(label,axis=1)):
    error[i,idx] -=1
    ls_loss += error[i,idx]

результат:

error: 
 [[ 0.5488135   0.71518937  0.60276338  0.54488318  0.4236548 ]
 [ 0.64589411  0.43758721  0.891773    0.96366276  0.38344152]]
label: 
 [[ 0.  0.  0.  1.  0.]
 [ 0.  1.  0.  0.  0.]]

error(indexes 3 and 1 are changed): 
[[ 0.5488135   0.71518937  0.60276338 -0.45511682  0.4236548 ]
 [ 0.64589411 -0.56241279  0.891773    0.96366276  0.38344152]]
per_ts_loss: 
 -1.01752960574

Вот сам код: https://ideone.com/e1k8ra

Я застрял на том, как использовать результат np.argmax, так как результатом является новый векториндексов, и его нельзя просто использовать как:

 error[:, np.argmax(label, axis=1)] -=1

Так что я застрял здесь!

Ответы [ 2 ]

1 голос
/ 17 марта 2019

Заменить:

error[:, np.argmax(label, axis=1)] -=1

с:

error[np.arange(error.shape[0]), np.argmax(label, axis=1)] -=1

и, конечно,

loss = error[np.arange(error.shape[0]), np.argmax(label, axis=1)].sum()

В вашем примере вы меняете и суммируете error[0,3] и error[1,1], или коротко error[[0,1],[3,1]].

0 голосов
/ 17 марта 2019

Может быть, это:

import numpy as np


error = np.array([[0.32783139, 0.29204386, 0.0572163 , 0.96162543, 0.8343454 ],
       [0.67308787, 0.27715222, 0.11738748, 0.091061  , 0.51806117]])

label= np.array([[0, 0, 0, 1, 0 ],
           [0, 1, 0, 0, 0]])



def f(error, label):
    per_ts_loss=0
    t=np.zeros(error.shape)
    argma=np.argmax(label, axis=1)
    t[[i for i in range(error.shape[0])],argma]=-1
    print(t)
    error+=t
    per_ts_loss += error[[i for i in range(error.shape[0])],argma]


f(error, label)

Ouput:

[[ 0.  0.  0. -1.  0.]
 [ 0. -1.  0.  0.  0.]]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...