Нахождение вычитания сдвинутого тензора - PullRequest
0 голосов
/ 25 сентября 2019

Я пытаюсь выяснить, как сделать сдвиг на тензоре, у которого b (batch size), d (depth), h (hight) and w (width) представлен следующим образом:

b, d, h, w = tensor.size()

Итак, мне нужно найти вычитание между смещенным тензором и тензоромсам.

Я думаю об использовании torch.narrow или torch.concat, чтобы сделать это для каждой стороны (сдвиг вправо, влево, вверх, затем вниз) и каждый раз вычитать из одной и той же тензорной стороны.(сторона самого тензора), затем в конце я добавлю / сложу разности / вычитания каждой стороны (поэтому у меня будет окончательное вычитание между сдвинутым и самим тензором.

Я новичок в PyTorch, это легко понять, но трудно реализовать, и, возможно, есть более простой способ (непосредственно вычитание, а не работа с каждой стороны и т. д.)

Любая помощь по этому вопросу, пожалуйста?

1 Ответ

0 голосов
/ 26 сентября 2019

По сути, вы можете сначала разбить тензор, а затем упорядочить их в обратном порядке.Я пишу функцию для реализации ваших мыслей.shift должно быть неотрицательным числом и меньше или равно размеру dim.

def tensor_shift(t, dim, shift):
    """
    t (tensor): tensor to be shifted. 
    dim (int): the dimension apply shift.
    shift (int): shift distance.
    """
    assert 0 <= shift <= t.size(dim), "shift distance should be smaller than or equal to the dim length."

    overflow = t.index_select(dim, torch.arange(t.size(dim)-shift, t.size(dim)))
    remain = t.index_select(dim, torch.arange(t.size(dim)-shift))

    return torch.cat((overflow, remain),dim=dim)

Вот некоторые результаты теста.

a = torch.arange(1,13).view(-1,3)
a
#tensor([[ 1,  2,  3],
#        [ 4,  5,  6],
#        [ 7,  8,  9],
#        [10, 11, 12]])

shift(a, 0, 1) # shift 1 unit along dim=0
#tensor([[10, 11, 12],
#        [ 1,  2,  3],
#        [ 4,  5,  6],
#        [ 7,  8,  9]])

b = torch.arange(1,13).view(-1,2,3)
b
#tensor([[[ 1,  2,  3],
#         [ 4,  5,  6]],
#
#        [[ 7,  8,  9],
#         [10, 11, 12]]])

shift(b, 1, 1) # shift 1 unit along dim=1
#tensor([[[ 4,  5,  6],
#         [ 1,  2,  3]],
#
#        [[10, 11, 12],
#         [ 7,  8,  9]]])
...