Факел - интерполировать пропущенные значения - PullRequest
1 голос
/ 15 января 2020

У меня есть запас тензорных изображений в форме NumOfImagesxHxW, который включает нули. Я ищу способ интерполировать пропущенные значения (нули), используя информацию только на одном изображении (нет связи между изображениями). Есть ли способ сделать это с помощью pytorch?

Кажется, что F.interpolate работает только для изменения формы. Мне нужно заполнить нули, сохранив при этом размеры и градиенты тензора.

Спасибо.

1 Ответ

0 голосов
/ 19 января 2020

РЕДАКТИРОВАТЬ: Оказывается, что ниже не отвечает OP, поскольку он не обеспечивает решение для отслеживания градиентов для обратного распространения. Все еще оставляя это, поскольку это может использоваться как часть решения.

Один из способов - преобразовать тензор в массив numpy и использовать интерполяцию scipy, например, scipy.interpolate.LinearGridInterpolator [1] или другие возможные numpy опции интерполяции массива ( some подробно здесь ). Не уверен, что это помогает, так как это не Pytorch + может включать копирование тензора вокруг.

Поскольку интерполяция scipy может быть медленной, одним из возможных решений является использование только пикселей, смежных с пропущенными значениями, для интерполяции (может быть легко получено путем расширения по маске пропущенных значений). Я думаю, что это может ускорить процесс на порядок, углубившись в тензорные измерения и количество пропущенных значений.

Редактировать: реализовано, в моем случае ускорение на два порядка.

def fillMissingValues(target_for_interp, copy=True, 
                      interpolator=scipy.interpolate.LinearNDInterpolator): 
    import cv2, scipy, numpy as np

    if copy: 
        target_for_interp = target_for_interp.copy()

    def getPixelsForInterp(img): 
        """
        Calculates a mask of pixels neighboring invalid values - 
           to use for interpolation. 
        """
        # mask invalid pixels
        invalid_mask = np.isnan(img) + (img == 0) 
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

        #dilate to mark borders around invalid regions
        dilated_mask = cv2.dilate(invalid_mask.astype('uint8'), kernel, 
                          borderType=cv2.BORDER_CONSTANT, borderValue=int(0))

        # pixelwise "and" with valid pixel mask (~invalid_mask)
        masked_for_interp = dilated_mask *  ~invalid_mask
        return masked_for_interp.astype('bool'), invalid_mask

    # Mask pixels for interpolation
    mask_for_interp, invalid_mask = getPixelsForInterp(target_for_interp)

    # Interpolate only holes, only using these pixels
    points = np.argwhere(mask_for_interp)
    values = target_for_interp[mask_for_interp]
    interp = interpolator(points, values)

    target_for_interp[invalid_mask] = interp(np.argwhere(invalid_mask))
    return target_for_interp

# For the target tensor: 
target_filled = fillMissingValues(target.numpy().squeeze())

# transform back to tensor etc..

Обратите внимание, что интерполированные значения будут np.nan вне выпуклой оболочки действительной баллы, как указано LinearNDInterpolator.

...