Как получить горизонтальный и вертикальный градиент разницы между y_true и y_pred? - PullRequest
0 голосов
/ 01 октября 2018

Я хочу определить пользовательскую функцию потерь, используя Keras, которая содержит градиент разницы между y_true и y_pred.Я обнаружил, что numpy.gradient может помочь мне получить градиент массива.Поэтому часть моего кода для функции потерь выглядит следующим образом:

def loss(y_true, y_pred):
    d   = y_true - y_pred
    gradient_x = np.gradient(d, axis=0)
    gradient_y = np.gradient(d, axis=1)

, но оказывается, d является тензорным классом Tensorflow и numpy.gradient не может его обработать.Я новичок в Keras и Tensorflow.

Есть ли какая-либо другая функция, которая может помочь мне сделать это?Или я должен сам рассчитать градиент?

1 Ответ

0 голосов
/ 01 октября 2018

Тензорные потоки вообще не являются массивами, когда они выполняются, они являются только ссылками на строящийся вычислительный граф.Возможно, вы захотите просмотреть учебник о том, как Tensorflow строит графики .

У вас есть две проблемы с вашей функцией потерь: во-первых, свертывание по любой оси не приведет к скаляру, поэтому он выиграл 'Невозможно взять производную, а во-вторых, np.gradient, по-видимому, не существует в Tensorflow.

Для первой проблемы вы можете решить ее, уменьшив вдоль оставшейся оси gradient_y или gradient_x.Я не знаю, какую функцию вы можете использовать, потому что я не знаю ваше приложение.

Вторую проблему можно решить двумя способами:

  1. Вы можете обернуть np.gradient с использованием py_func, но вы планируете использовать это как функцию потерь, так что вы захотите взять градиент этой функции, и определение градиента вызова py_func равно сложный .
  2. Напишите свою собственную версию np.gradient, используя чистый Tensorflow.

Например, вот 1D np.gradient в тензорном потоке ( не проверено ):

def gradient(x):
    d = x[1:]-x[:-1]
    fd = tf.concat([x,x[-1]], 0).expand_dims(1)
    bd = tf.concat([x[0],x], 0).expand_dims(1)
    d = tf.concat([fd,bd], 1)
    return tf.reduce_mean(d,1)
...