Я хочу изменить типичную функцию потерь MSE. Прямо сейчас у меня есть следующий код:
squared_difference = tf.reduce_sum(tf.square(target - output), [1])
mse_loss = tf.reduce_mean(squared_difference)
форма обоих тензоров равна [batch_size, 10]
, а примером для цели является [0,1,2,3,0.5,0.5,0.5,7,8,9]
. 0.5
s всегда находятся в индексах 4, 5 и 6.
Что я хочу сделать сейчас, так это полностью игнорировать эти индексы и не увеличивать потери, если выходной сигнал сети не имеет 0,5 при этих индексах.
Таким образом, если выходное значение равно [0,1,2,3,20,10,14,7,8,9]
, потеря должна составлять 0
.
Каков наилучший способ достичь этого?