Вот моя реализация (в качестве примера беру логиты размерности [3, 5]):
Версия TensorFlow:
import tensorflow as tf
def kl_loss_compute(logits1, logits2):
    """Return the mean KL divergence KL(pred2 || pred1) of two logit tensors.

    Both inputs are converted to probabilities with a softmax over the last
    axis. The divergence is summed over that same axis and then averaged
    over the remaining ones.

    Args:
        logits1: Logits of the reference distribution (``pred1``).
        logits2: Logits of the distribution the expectation is taken
            under (``pred2``).

    Returns:
        A scalar tensor holding the averaged divergence.

    Note:
        The intermediate probabilities are printed via ``Tensor.eval()``,
        which needs a default session, so call this function inside a
        ``with tf.Session():`` block.
    """
    pred1 = tf.nn.softmax(logits1)
    print(pred1.eval())
    pred2 = tf.nn.softmax(logits2)
    print(pred2.eval())
    # The 1e-8 terms keep both the division and the logarithm away from zero.
    log_ratio = tf.log(1e-8 + pred2 / (pred1 + 1e-8))
    # Reduce over the last axis -- the one softmax normalised -- so the
    # result is also correct for inputs that are not rank 2.
    per_sample_kl = tf.reduce_sum(pred2 * log_ratio, axis=-1)
    return tf.reduce_mean(per_sample_kl)
# Two random [3, 5] logit matrices, defined as graph tensors.
x1 = tf.random.normal(shape=[3, 5], dtype=tf.float32)
x2 = tf.random.normal(shape=[3, 5], dtype=tf.float32)

with tf.Session() as session:
    # Evaluate each tensor once and keep the concrete values, so the numbers
    # printed here are exactly the ones passed to the loss below.
    x1 = session.run(x1)
    print(x1)
    x2 = session.run(x2)
    print(x2)
    print("=" * 30)
    kl_tensor = kl_loss_compute(x1, x2)
    print(session.run(kl_tensor))
Вывод:
[[ 0.9801388 -0.2514422 -0.28299806 0.85130763 0.4565948 ]
[-1.0744809 0.20301117 0.21026622 1.0385195 0.41147012]
[ 1.2385081 1.1003486 -2.0818367 -1.0446491 1.8817908 ]]
[[ 0.04036871 0.82306993 0.82962424 0.5209219 -0.10473887]
[ 1.7777447 -0.6257034 -0.68985045 -1.1191329 -0.2600192 ]
[ 0.03387258 0.44405013 0.08010675 0.9131149 0.6422863 ]]
==============================
[[0.32828477 0.09580362 0.09282765 0.2886025 0.19448158]
[0.04786159 0.17170973 0.17296004 0.39596024 0.21150835]
[0.2556382 0.22265059 0.00923886 0.02606533 0.48640704]]
[[0.12704821 0.27790183 0.27972925 0.20543297 0.10988771]
[0.7349108 0.06644011 0.062312 0.04056362 0.09577343]
[0.12818882 0.19319147 0.13425465 0.30881628 0.23554876]]
0.96658206
Версия PyTorch:
def kl_loss_compute(logits1, logits2):
    """Compute the mean KL divergence KL(pred2 || pred1) from two logit tensors.

    Each argument is softmax-normalised along its last dimension (as
    float32) and printed. The divergence is summed over that dimension and
    averaged over the rest, giving a scalar tensor.
    """
    pred1 = torch.softmax(logits1, dim=-1, dtype=torch.float32)
    print(pred1)
    pred2 = torch.softmax(logits2, dim=-1, dtype=torch.float32)
    print(pred2)
    # The 1e-8 terms guard the division and the logarithm against zeros.
    log_ratio = (1e-8 + pred2 / (pred1 + 1e-8)).log()
    return (pred2 * log_ratio).sum(dim=-1).mean()
# The same inputs as above are used here (the values printed by the
# TensorFlow run). torch.tensor builds a tensor from existing data; the
# dtype is spelled out instead of relying on the legacy torch.Tensor
# constructor's default type.
x = torch.tensor(
    [
        [0.9801388, -0.2514422, -0.28299806, 0.85130763, 0.4565948],
        [-1.0744809, 0.20301117, 0.21026622, 1.0385195, 0.41147012],
        [1.2385081, 1.1003486, -2.0818367, -1.0446491, 1.8817908],
    ],
    dtype=torch.float32,
)
y = torch.tensor(
    [
        [0.04036871, 0.82306993, 0.82962424, 0.5209219, -0.10473887],
        [1.7777447, -0.6257034, -0.68985045, -1.1191329, -0.2600192],
        [0.03387258, 0.44405013, 0.08010675, 0.9131149, 0.6422863],
    ],
    dtype=torch.float32,
)
print(kl_loss_compute(x, y))
Вывод:
tensor([[0.3283, 0.0958, 0.0928, 0.2886, 0.1945],
[0.0479, 0.1717, 0.1730, 0.3960, 0.2115],
[0.2556, 0.2227, 0.0092, 0.0261, 0.4864]])
tensor([[0.1270, 0.2779, 0.2797, 0.2054, 0.1099],
[0.7349, 0.0664, 0.0623, 0.0406, 0.0958],
[0.1282, 0.1932, 0.1343, 0.3088, 0.2355]])
tensor(0.9666)