Обновление
Более эффективная и короткая версия. Чтобы избежать использования цикла for, мы можем сначала переставить a
.
import torch
a = torch.arange(9*2*2).view(9,1,2,2)
b = a.permute([0,1,3,2])
torch.cat(torch.split(b, 3),-1).view(6,6).t()
# tensor([[ 0, 1, 4, 5, 8, 9],
# [ 2, 3, 6, 7, 10, 11],
# [12, 13, 16, 17, 20, 21],
# [14, 15, 18, 19, 22, 23],
# [24, 25, 28, 29, 32, 33],
# [26, 27, 30, 31, 34, 35]])
Оригинальный ответ
Вы можете использовать torch.split
и torch.cat
для реализацииЭто.
import torch
a = torch.arange(9*2*2).view(9,1,2,2)
Предполагается, что у нас есть тензор a
, который является мини-версией вашего оригинального тензора. И, похоже,
tensor([[[[ 0, 1],
[ 2, 3]]],
[[[ 4, 5],
[ 6, 7]]],
[[[ 8, 9],
[10, 11]]],
[[[12, 13],
[14, 15]]],
[[[16, 17],
[18, 19]]],
[[[20, 21],
[22, 23]]],
[[[24, 25],
[26, 27]]],
[[[28, 29],
[30, 31]]],
[[[32, 33],
[34, 35]]]])
Каждая подматрица 2x2 может рассматриваться как одно изображение. То, что вы хотите сделать, это сложить первые три изображения в один ряд, следующие три изображения во второй ряд и последние три изображения в третий ряд. «Строка» на самом деле имеет два затемнения из-за подматрицы 2x2.
three_parts = torch.split(a,3)
torch.cat(torch.split(three_parts[0],1), dim=-1)
#tensor([[[[ 0, 1, 4, 5, 8, 9],
# [ 2, 3, 6, 7, 10, 11]]]])
Здесь мы берем только первую часть.
torch.cat([torch.cat(torch.split(three_parts[i],1),-1) for i in range(3)],0).view(6,6)
# tensor([[ 0, 1, 4, 5, 8, 9],
# [ 2, 3, 6, 7, 10, 11],
# [12, 13, 16, 17, 20, 21],
# [14, 15, 18, 19, 22, 23],
# [24, 25, 28, 29, 32, 33],
# [26, 27, 30, 31, 34, 35]])