Я хочу изменить свой код на форму цикла в theano, используя функцию theano.scan, но результат не тот, о котором я думал. Я определил глобальную переменную под названием tau_mu, я думаю, что это может привести к проблеме, но я не могу понять это.
I want to change the below code to the form of loop:
(1)
import theano
import theano.tensor as TT
t=np.arange(10)
tau = TT.vector('tau')
mu=TT.vector('mu')
tau_mu1=TT.switch(tau[0]>=t,mu[0],mu[1])
tau_mu2=TT.switch(tau[1]>=t,tau_mu1,mu[2])
tau_mu3=TT.switch(tau[2]>=t,tau_mu2,mu[3])
f = theano.function([tau,mu], tau_mu3)
f([3,5,7],[2,4,5,6])
результат:
массив ([2., 2., 2., 2., 4., 4., 5., 5., 6., 6., 6., 6., 6., 6., 6.])
(2)the following is the form of loop
tau = TT.vector('tau')
mu=TT.vector('mu')
tau_mu=TT.vector('tau_mu')
tau_mu=TT.switch(tau[0]>=t,mu[0],mu[1])
indc=TT.ivector('indc')
def one_step(indc,tau,mu):
global tau_mu
tau_mu=TT.switch(tau[indc]>=t,tau_mu,mu[indc+1])
return tau_mu
result,updates=theano.scan(fn=one_step,sequences=[indc],non_sequences=[tau,mu])
f = theano.function([indc,tau,mu], result)
f([1,2],[3,5,7],[2,4,5,6])
результат:
array([[2., 2., 2., 2., 4., 4., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
[2., 2., 2., 2., 4., 4., 4., 4., 6., 6., 6., 6., 6., 6., 6.]])
окончательный результат
array([2., 2., 2., 2., 4., 4., 4., 4., 6., 6., 6., 6., 6., 6., 6.]]),
мой ожидаемый результат
array([2., 2., 2., 2., 4., 4., 5., 5., 6., 6., 6., 6., 6., 6., 6.])
где я должен изменить свой код?
Заранее благодарим за любые рекомендации, которые вы можете предоставить.