По сути, вы можете сначала разбить тензор, а затем упорядочить их в обратном порядке.Я пишу функцию для реализации ваших мыслей.shift
должно быть неотрицательным числом и меньше или равно размеру dim
.
def tensor_shift(t, dim, shift):
"""
t (tensor): tensor to be shifted.
dim (int): the dimension apply shift.
shift (int): shift distance.
"""
assert 0 <= shift <= t.size(dim), "shift distance should be smaller than or equal to the dim length."
overflow = t.index_select(dim, torch.arange(t.size(dim)-shift, t.size(dim)))
remain = t.index_select(dim, torch.arange(t.size(dim)-shift))
return torch.cat((overflow, remain),dim=dim)
Вот некоторые результаты теста.
a = torch.arange(1,13).view(-1,3)
a
#tensor([[ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9],
# [10, 11, 12]])
shift(a, 0, 1) # shift 1 unit along dim=0
#tensor([[10, 11, 12],
# [ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9]])
b = torch.arange(1,13).view(-1,2,3)
b
#tensor([[[ 1, 2, 3],
# [ 4, 5, 6]],
#
# [[ 7, 8, 9],
# [10, 11, 12]]])
shift(b, 1, 1) # shift 1 unit along dim=1
#tensor([[[ 4, 5, 6],
# [ 1, 2, 3]],
#
# [[10, 11, 12],
# [ 7, 8, 9]]])