Если вы используете только среднее значение для 9 элементов, центрированных в каждом пикселе, тогда лучшим вариантом будет использование двумерной свертки с постоянным фильтром 3x3:
import torch.nn.functional as nnf
def mean_filter(x_bchw):
"""
Calculating the mean of each 3x3 neighborhood.
input:
- x_bchw: input tensor of dimensions batch-channel-height-width
output:
- y_bchw: each element in y is the average of the 9 corresponding elements in x_bchw
"""
# define the filter
box = torch.ones((3, 3), dtype=x_bchw.dtype, device=x_bchw.device, requires_grad=False)
box = box / box.sum()
box = box[None, None, ...].repeat(x_bchw.size(1), 1, 1, 1)
# use grouped convolution - so each channel is averaged separately.
y_bchw = nnf.conv2d(x_bchw, box, padding=1, groups=x_bchw.size(1))
return y_bchw
однако, если вы хотите чтобы применить более сложную функцию для каждого района, вы можете использовать nn.Unfold
. Эта операция преобразует каждый 3x3 (или любую определенную вами область angular прямоугольника) в вектор. Когда у вас есть все векторы, вы можете применить к ним свою функцию.
См. этот ответ для получения более подробной информации о unfold
и fold
.