Получите градиент тензорного потока для необычной пользовательской функции потерь, заполняющей матрицу - PullRequest
0 голосов
/ 26 мая 2020

Я написал пользовательскую функцию потерь в тензорном потоке, но операции в моей потере нелегко дифференцировать. Это задача компьютерного зрения с использованием CNN. Результатом моей сети является вектор, определяющий некоторые контрольные точки для сплайн-представления кривой. Однако наземные метки истинности представляют собой матрицы, представляющие изображения. Идея функции потерь состоит в том, чтобы использовать контрольные точки, выведенные сетью, для «рисования» кривой для создания предсказанного изображения, а затем использовать некоторый традиционный метод для сравнения предсказанного изображения с наземным истинным изображением.

Обычно я понимаю, как tf.GradientTape работает для простых tf-операций, но я не могу заставить его работать для более сложных преобразований, происходящих в моей функции потерь. Пример кода:

# p is the output from the vector
p = tf.Variable(tf.constant([[1.0, 2.0],[3.0, 4.0], [5.0, 6.0], [1.0, 2.0]], dtype='float32'))
d = 2 #degree of spline
rows, cols = 256, 256 #dimensions of images
n = p.shape[0]
t = tf.linspace(0.0, 1.0, n+d+1) #steps for calc bspline

with tf.GradientTape() as tape:
    tape.watch(p)
    spline = BSpline(p, d, t) #returns a matrix of coordinates drawing the curve
    img_pred = tf.Variable(tf.zeros((256, 256)))
    for element in spline:
        img_pred[element].assign(1.0)

print(tape.gradient(img_pred, p)

Код здесь следует рассматривать как псевдокод. Моя основная проблема заключается в отслеживании градиента того, как изменения в сетевом выходе p влияют на окончательное предсказанное изображение img_pred, когда операции (насколько мне известно) не могут быть записаны в простых функциях тензорного потока

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...