соседи клетки в матричном питорче - PullRequest
1 голос
/ 21 января 2020

Я пытаюсь получить соседей ячейки матрицы в pytorch, используя приведенную ниже часть кода. это работает правильно, но это очень много времени. Есть ли у вас предложения, чтобы получить его быстрее

def neighbour(x):
    result=F.pad(input=x, pad=(1, 1, 1, 1), mode='constant', value=0)
    for m in range(1,x.size(0)+1):
        for n in range(1,x.size(1)+1):
                y=torch.Tensor([result[m][n],result[m-1][n-1],result[m-1][n],result[m-1] 
           [n+1],result[m][n-1],result[m][n+1],result[m+1][n-1],result[m+1][n],result[m+1][n+1]])
                x[m-1][n-1]=y.mean()

    return x

1 Ответ

1 голос
/ 21 января 2020

Если вы используете только среднее значение для 9 элементов, центрированных в каждом пикселе, тогда лучшим вариантом будет использование двумерной свертки с постоянным фильтром 3x3:

import torch.nn.functional as nnf

def mean_filter(x_bchw):
  """
  Calculating the mean of each 3x3 neighborhood.
  input:
    - x_bchw: input tensor of dimensions batch-channel-height-width
  output:
    - y_bchw: each element in y is the average of the 9 corresponding elements in x_bchw
  """
  # define the filter
  box = torch.ones((3, 3), dtype=x_bchw.dtype, device=x_bchw.device, requires_grad=False)  
  box = box / box.sum()
  box = box[None, None, ...].repeat(x_bchw.size(1), 1, 1, 1)
  # use grouped convolution - so each channel is averaged separately.  
  y_bchw = nnf.conv2d(x_bchw, box, padding=1, groups=x_bchw.size(1))
  return y_bchw

однако, если вы хотите чтобы применить более сложную функцию для каждого района, вы можете использовать nn.Unfold. Эта операция преобразует каждый 3x3 (или любую определенную вами область angular прямоугольника) в вектор. Когда у вас есть все векторы, вы можете применить к ним свою функцию.
См. этот ответ для получения более подробной информации о unfold и fold.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...