Этот код
@jit(nopython=True)
def _sgd_batch(learning_rate, w, xtrain, ytrain):
for i in range(xtrain.shape[1]):
wx = w*xtrain[:,i]
error = wx - ytrain[:,i]
for j in range(w.shape[0]):
w[j,:] = w[j,:]-learning_rate*error[j]*np.transpose(xtrain[:,i])
return w
вызывает исключение ValueError: unable to broadcast argument 1 to output array
при вызове со следующими параметрами: learning_rate
- скаляр, w
- матрица 3x1000, xtrain - матрица 1000 x 10k,и ytrain - это матрица 3 x 100 тыс.Когда я запускаю его без @jit
, он работает просто отлично.
Все поиски в Google по запросу numba "unable to broadcast argument"
выполняются либо по коду , который вызывает исключение (отмечу, что строка не включенатрассировка, но вмешательство с помощью отладчика подтверждает, что это именно эта строка) и проблема github, открывающая ошибку о том, что ошибка не возникла (без, к сожалению, особых причин, когдаэто надо поднять).Без кавычек в поиске в Google я получаю множество откликов на проблемы с бесшумной трансляцией, которых у меня нет.Я не могу понять смысл кода, который приводит к возникновению ошибки на этом этапе, поэтому я задаюсь вопросом: кто-нибудь сталкивался с этим раньше и знает, что означает эта ошибка?