Как я могу определить глобальную переменную в theano? Могу ли я использовать глобальную переменную, чтобы определить ее? - PullRequest
0 голосов
/ 22 июня 2019

Я хочу изменить свой код на форму цикла в theano, используя функцию theano.scan, но результат не тот, о котором я думал. Я определил глобальную переменную под названием tau_mu, я думаю, что это может привести к проблеме, но я не могу понять это.

I want to change the below code to the form of loop:

(1)

import theano
import theano.tensor as TT
t=np.arange(10)
tau = TT.vector('tau')
mu=TT.vector('mu')
tau_mu1=TT.switch(tau[0]>=t,mu[0],mu[1])
tau_mu2=TT.switch(tau[1]>=t,tau_mu1,mu[2])
tau_mu3=TT.switch(tau[2]>=t,tau_mu2,mu[3])
f = theano.function([tau,mu], tau_mu3)
f([3,5,7],[2,4,5,6])

результат: массив ([2., 2., 2., 2., 4., 4., 5., 5., 6., 6., 6., 6., 6., 6., 6.])

(2)the following is the form of loop
tau = TT.vector('tau')
mu=TT.vector('mu')
tau_mu=TT.vector('tau_mu')
tau_mu=TT.switch(tau[0]>=t,mu[0],mu[1])
indc=TT.ivector('indc')
def one_step(indc,tau,mu):
    global tau_mu
    tau_mu=TT.switch(tau[indc]>=t,tau_mu,mu[indc+1])
    return tau_mu
result,updates=theano.scan(fn=one_step,sequences=[indc],non_sequences=[tau,mu])
f = theano.function([indc,tau,mu], result)

f([1,2],[3,5,7],[2,4,5,6])

результат:

array([[2., 2., 2., 2., 4., 4., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [2., 2., 2., 2., 4., 4., 4., 4., 6., 6., 6., 6., 6., 6., 6.]])

окончательный результат

array([2., 2., 2., 2., 4., 4., 4., 4., 6., 6., 6., 6., 6., 6., 6.]]),

мой ожидаемый результат

array([2., 2., 2., 2., 4., 4., 5., 5., 6., 6., 6., 6., 6., 6., 6.])

где я должен изменить свой код? Заранее благодарим за любые рекомендации, которые вы можете предоставить.

...