Я пытаюсь определить свою собственную потерю, я реализую Гиперболическое расстояние в диске Пуанкаре, которое не определено, если одна из точек имеет радиус> = 1.
Поэтому я хочу использоватьtf.where, чтобы избежать значений nan и вернуть mse вместо nan.
Вот мой код:
def hyperbolic_loss(y_true, y_pred):
num = 2 * tf.norm(y_true - y_pred, axis = 1)
densx = (1 - tf.norm(y_pred, axis = 1))
dendx = (1 - tf.norm(y_true, axis = 1))
frac = num/(densx * dendx)
acos = tf.math.acosh(1 + frac)
acos = tf.diag(acos)
ret = K.mean(K.mean(acos, axis = -1))
mse = K.mean(mean_squared_error(y_true, y_pred))
return tf.where(
tf.reduce_any(tf.is_nan(acos)),
mse * 10,
ret)
Я пробовал его вне сети, и он работает с любым входом (дажеесли есть точки с радиусом> = 1), но когда я использую функцию в качестве функции потерь в сети, может случиться так, что будет возвращено nan
.
Расстояние сообщается здесь , r равно 1.
Я думаю, что проблема в использовании tf.where, но я не уверен.
Как мне сделать, чтобы избежать nan
возврата?