Я не уверен, действительно ли это то, что вы хотите, или что-то, что вы можете использовать, но почти как своего рода упражнение я собрал версию алгоритма TensorFlow, которую вы там поместили. Я сделал это для произвольных карт смещения по вертикальным и горизонтальным координатам, так что ваша проблема была бы подзаголовком. Он немного отличается от вашего алгоритма, здесь, если два пикселя перемещаются по одной и той же смещенной координате, их значения агрегируются, а не просто берут последний. Кроме того, этот алгоритм использует билинейную интерполяцию для вычисления дробных смещений, что делает его дифференцируемым.
import tensorflow as tf
def image_displace_tf(img, disp):
# Get image shape
s = tf.shape(img, out_type=tf.int64)
# Image shape as floats
sf = tf.dtypes.cast(s, disp.dtype)
# Make coordinate grid
bb, ii, jj = tf.meshgrid(tf.range(s[0]), tf.range(s[1]), tf.range(s[2]),
indexing='ij')
# Compute displaced coordinates
coords = tf.stack([bb, ii, jj], axis=-1)
# Add a null "batch displacement"
disp_pad = tf.pad(disp, [[0, 0], [0, 0], [0, 0], [1, 0]])
coords_disp = tf.dtypes.cast(coords, disp.dtype) + disp_pad
# Mask displacements out of range
mask = ((coords_disp[..., 1] >= 0) & (coords_disp[..., 1] < sf[1] - 1) &
(coords_disp[..., 2] >= 0) & (coords_disp[..., 2] < sf[2] - 1))
coords = tf.boolean_mask(coords, mask)
coords_disp = tf.boolean_mask(coords_disp, mask)
# Compute interpolation alpha values for bilinear interpolation
alpha_1 = tf.math.floormod(coords_disp, 1.0)
alpha_1 = tf.expand_dims(alpha_1, axis=-1)
alpha_0 = 1 - alpha_1
alpha_00 = alpha_0[..., 1, :] * alpha_0[..., 2, :]
alpha_01 = alpha_0[..., 1, :] * alpha_1[..., 2, :]
alpha_10 = alpha_1[..., 1, :] * alpha_0[..., 2, :]
alpha_11 = alpha_1[..., 1, :] * alpha_1[..., 2, :]
# Begin and end indices for each dimension
idx_00 = tf.dtypes.cast(coords_disp, tf.int64)
idx_01 = idx_00 + [0, 0, 1]
idx_10 = idx_00 + [0, 1, 0]
idx_11 = idx_00 + [0, 1, 1]
# Values at begin and end for each dimension scaled by their alpha values
img_coords = tf.gather_nd(img, coords)
value_00 = alpha_00 * img_coords
value_01 = alpha_01 * img_coords
value_10 = alpha_10 * img_coords
value_11 = alpha_11 * img_coords
# Concatenate all indices and values
idx_all = tf.concat([idx_00, idx_01, idx_10, idx_11], axis=0)
value_all = tf.concat([value_00, value_01, value_10, value_11], axis=0)
# Make aggregated result
return tf.scatter_nd(idx_all, value_all, s)
Я не уверен, что он будет работать, как вы надеетесь, но вы можете взять градиенты как изображения, так и карта смещения. Вот как это можно использовать:
import tensorflow as tf
import matplotlib.pyplot as plt
tf.random.set_seed(0)
plt.close('all')
# Make a radial image
x = tf.linspace(-1.0, 1.0, 400)
y = tf.linspace(-1.0, 1.0, 300)
r = tf.math.sqrt(tf.math.square(x) + tf.math.square(tf.expand_dims(y, 1)))
img = tf.math.maximum(0.8 - r, 0.0)
# Give it three channels
img = tf.tile(tf.expand_dims(img, axis=-1), [1, 1, 3])
# Add batch dimension
img = tf.expand_dims(img, axis=0)
# Show image
plt.figure()
plt.imshow(img.numpy()[0])
plt.title('Source image')
plt.show()
# Make a wavy displacement map along the horizontal axis
s = tf.shape(img)
disp_i = tf.zeros(s[:-1], dtype=img.dtype)
_, ii, jj = tf.meshgrid(*(tf.linspace(0.0, 1.0, si) for si in s[:3]), indexing='ij')
disp_j = 5.0 * (tf.sin(40.0 * jj + tf.sin(20.0 * ii)))
disp = tf.stack([disp_i, disp_j], axis=-1)
# Show map
plt.figure()
plt.imshow(0.5 + 0.5 * disp.numpy()[0, ..., 1])
plt.title('Horizontal displacement map')
plt.show()
# Do displacement
with tf.GradientTape() as g:
g.watch(img)
g.watch(disp)
img_disp = image_displace_tf(img, disp)
# Some minimization goal - e.g. sum of squared pixel values
goal = tf.math.reduce_sum(tf.square(img_disp))
# Show result
plt.figure()
plt.imshow(img_disp.numpy()[0])
plt.title('Displaced image')
plt.show()
# Show gradients
img_grad, disp_grad = g.gradient(goal, [img, disp])
# Gradient of image
plt.figure()
plt.imshow(img_grad.numpy()[0, ..., 0])
plt.title('Image gradient')
plt.show()
# Gradient of horizontal coorindate of displacement map
plt.figure()
plt.imshow(disp_grad.numpy()[0, ..., 1])
plt.title('Horizontal displacement gradient')
plt.show()
Выходные изображения:
![Horizontal displacement gradient](https://i.stack.imgur.com/cuD24.png)