РЕДАКТИРОВАТЬ: Оказывается, что ниже не отвечает OP, поскольку он не обеспечивает решение для отслеживания градиентов для обратного распространения. Все еще оставляя это, поскольку это может использоваться как часть решения.
Один из способов - преобразовать тензор в массив numpy и использовать интерполяцию scipy, например, scipy.interpolate.LinearGridInterpolator
[1] или другие возможные numpy опции интерполяции массива ( some подробно здесь ). Не уверен, что это помогает, так как это не Pytorch + может включать копирование тензора вокруг.
Поскольку интерполяция scipy может быть медленной, одним из возможных решений является использование только пикселей, смежных с пропущенными значениями, для интерполяции (может быть легко получено путем расширения по маске пропущенных значений). Я думаю, что это может ускорить процесс на порядок, углубившись в тензорные измерения и количество пропущенных значений.
Редактировать: реализовано, в моем случае ускорение на два порядка.
def fillMissingValues(target_for_interp, copy=True,
interpolator=scipy.interpolate.LinearNDInterpolator):
import cv2, scipy, numpy as np
if copy:
target_for_interp = target_for_interp.copy()
def getPixelsForInterp(img):
"""
Calculates a mask of pixels neighboring invalid values -
to use for interpolation.
"""
# mask invalid pixels
invalid_mask = np.isnan(img) + (img == 0)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
#dilate to mark borders around invalid regions
dilated_mask = cv2.dilate(invalid_mask.astype('uint8'), kernel,
borderType=cv2.BORDER_CONSTANT, borderValue=int(0))
# pixelwise "and" with valid pixel mask (~invalid_mask)
masked_for_interp = dilated_mask * ~invalid_mask
return masked_for_interp.astype('bool'), invalid_mask
# Mask pixels for interpolation
mask_for_interp, invalid_mask = getPixelsForInterp(target_for_interp)
# Interpolate only holes, only using these pixels
points = np.argwhere(mask_for_interp)
values = target_for_interp[mask_for_interp]
interp = interpolator(points, values)
target_for_interp[invalid_mask] = interp(np.argwhere(invalid_mask))
return target_for_interp
# For the target tensor:
target_filled = fillMissingValues(target.numpy().squeeze())
# transform back to tensor etc..
Обратите внимание, что интерполированные значения будут np.nan
вне выпуклой оболочки действительной баллы, как указано LinearNDInterpolator
.