Я получаю сообщение об ошибке «Не удается различить тип w.r.t.» при использовании функции autograd в python.
По сути, я пытаюсь написать код для обобщенной линейной модели (GLM), и я хочу использовать autograd, чтобы получить функцию, описывающую производную функции потерь по w (весам), которую я бы затем подключить к scipy.optimize.minimize ().
Прежде чем выполнить шаг scipy, я пытался проверить, работает ли моя функция, вводя значения для переменных (которые в моем случае являются массивами) и печатая значение (снова в виде массива) для градиента в качестве выходных данных. Вот мой код:
def generate_data(n,k,m):
w = np.zeros((k,1)) # make first column of weights all zeros
w[:,[0]] = np.random.randint(-10, high=10,size=(k,m)) # choose length random inputs between -10 and 10
x = np.random.randint(-10, high=10,size=(n,m)) # choose length random inputs between -10 and 10
return x,w
def logpyx(x,w):
p = np.exp(np.dot(x,w.T)) # get exponentials e^wTx
norm = np.sum(p,axis=1) # get normalization constant (sum of exponentials)
pnorm = np.divide(p.T,norm).T # normalize the exponentials
ind = [] # initialize empty list
for n in np.arange(0,len(x)):
ind.append(np.random.choice(len(w),p = pnorm[n,:])) # choose index where y = 1 based on probabilities
ind = np.array(ind) # recast list as array
ys = [] # initialize empty list
for n in np.arange(0,len(x)):
y = [0] * (len(w)-1) # initialize list of zeros
y.insert(ind[n],1) # assign value "1" to appropriate index in row
ys.append(y) # add row to matrix of ys
y = np.array(ys) # recast list as array
pyx = np.diagonal(np.dot(pnorm,y.T)) # p(y|x)
log_pyx = np.log(pyx)
return log_pyx
# input data
n = 100 # number of data points
C = 2 # number of classes (e.g. turn right, turn left, move forward)
m = 1 # number of features in x (e.g. m = 2 for # of left trials and # of right trials)
log_pyx = logpyx(x,w) # calculate log likelihoods
grad_logpyx = grad(logpyx) # take gradient of log_pyx to find updated weights
x,w = generate_data(n,C,m)
print(grad_logpyx(x,w))
Так что, когда я делаю это, все работает нормально до последней строки, где я получаю ошибку, упомянутую ранее.
Я явно не понимаю, как использовать autograd очень хорошо, и я должен поместить что-то в неправильный формат, поскольку ошибка, по-видимому, связана с несоответствием типов данных. Любая помощь будет принята с благодарностью!