Гауссов фильтр в PyTorch - PullRequest
       0

Гауссов фильтр в PyTorch

0 голосов
/ 05 марта 2020

Я ищу способ применить фильтр Гаусса к изображению (тензор) только с использованием функций PyTorch. Используя numpy, эквивалентный код

import numpy as np
from scipy import signal
import matplotlib.pyplot as plt

# Define 2D Gaussian kernel
def gkern(kernlen=256, std=128):
    """Returns a 2D Gaussian kernel array."""
    gkern1d = signal.gaussian(kernlen, std=std).reshape(kernlen, 1)
    gkern2d = np.outer(gkern1d, gkern1d)
    return gkern2d

# Generate random matrix and multiply the kernel by it
A = np.random.rand(256*256).reshape([256,256])

# Test plot
plt.figure()
plt.imshow(A*gkern(256, std=32))
plt.show()

Самое близкое предложение, которое я нашел, основано на этом посте :

import torch.nn as nn

conv = nn.Conv2d(in_channels = 1, out_channels = 1, kernel_size=264, bias=False)
with torch.no_grad():
    conv.weight = gaussian_weights

Но оно дает мне ошибку NameError: name 'gaussian_weights' is not defined. Как я могу заставить это работать?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...