где в обычном потере - PullRequest
       9

где в обычном потере

0 голосов
/ 11 апреля 2019

Я пытаюсь определить свою собственную потерю, я реализую Гиперболическое расстояние в диске Пуанкаре, которое не определено, если одна из точек имеет радиус> = 1.

Поэтому я хочу использоватьtf.where, чтобы избежать значений nan и вернуть mse вместо nan.

Вот мой код:

def hyperbolic_loss(y_true, y_pred):    
    num = 2 * tf.norm(y_true - y_pred, axis = 1)
    densx = (1 - tf.norm(y_pred, axis = 1))
    dendx = (1 - tf.norm(y_true, axis = 1))
    frac = num/(densx * dendx)
    acos = tf.math.acosh(1  + frac)
    acos = tf.diag(acos)
    ret = K.mean(K.mean(acos, axis = -1))
    mse = K.mean(mean_squared_error(y_true, y_pred))

    return tf.where(
        tf.reduce_any(tf.is_nan(acos)),
        mse * 10, 
        ret)

Я пробовал его вне сети, и он работает с любым входом (дажеесли есть точки с радиусом> = 1), но когда я использую функцию в качестве функции потерь в сети, может случиться так, что будет возвращено nan.

Расстояние сообщается здесь , r равно 1.

Я думаю, что проблема в использовании tf.where, но я не уверен.

Как мне сделать, чтобы избежать nan возврата?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...