В некоторой точке моего кода мне нужно вычислить срок потери для каждого элемента моей партии отдельно, и после этого я хочу использовать tf.reduce_mean()
. После небольшого исследования я узнал о tf.while_loop()
, но, поскольку я новичок, я не могу понять, как правильно его использовать. Вот код, который я пробовал:
def loss_term2(x,mu,sigma):
i = tf.constant(0)
def condition(i):
return tf.less(i, 256)
def body(i):
x1 = tf.gather(x, i)
mu1 = tf.gather(mu, i)
sigma1 = tf.gather(sigma, i)
sigma1 = sigma1 + tf.eye(tf.shape(sigma1)[1])
sigmainv = tf.linalg.inv(sigma1)
prob = tf.matmul((x1-mu1), sigmainv)
prob = tf.matmul(prob, tf.linalg.transpose(x1-mu1))
return prob, i+1
prob , i = tf.while_loop(condition, body, i)
prob = tf.reduce_mean(prob)
return prob
x
, mu
и sigma
имеют одинаковые формы [batch_size=256, features=124]
. Я надеюсь, что мой код достаточно читабелен, чтобы понять мои намерения. Есть ли какое-нибудь решение сделать это с tf.while_loop()
?
или без него