Полагаю, ваша функция потерь действительно работает хорошо. Значение nan
, вероятно, исходит из прогнозов. Таким образом, условие tf.is_nan(y_actual)
не фильтрует его. Чтобы отфильтровать прогноз nan
, необходимо сделать следующее:
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.keras import backend as K
import numpy as np
def nan_mse(y_actual, y_predicted):
stack = tf.stack((tf.is_nan(y_actual),
tf.is_nan(y_predicted)),
axis=1)
is_nans = K.any(stack, axis=1)
per_instance = tf.where(is_nans,
tf.zeros_like(y_actual),
tf.square(tf.subtract(y_predicted, y_actual)))
print(per_instance)
return tf.reduce_mean(per_instance, axis=0)
print(nan_mse([1.,1.,np.nan,1.,0.], [1.,1.,0.,0.,np.nan]))
Out:
tf.Tensor(0.2, shape=(), dtype=float32)