Извлечь все блоки 3х3 из тензора - PullRequest
0 голосов
/ 14 марта 2020

Если у меня есть тензор 5x5, то как мне получить из него все 9 блоков 3x3, чтобы полученный тензор имел форму [9, 3, 3] или если эти блоки 3x3 сглажены, то [9, 9] форма. например,

x = torch.randn(5, 5)

предположим, что х равен

tensor([[ 0.5756,  0.2463,  1.3940,  0.8473, -0.8371],
        [ 0.9690,  1.4913, -0.2129,  0.8331, -0.6322],
        [-0.0348, -1.6920, -0.0157,  0.6159,  0.1038],
        [-1.0790,  1.4303,  0.3861,  0.1293,  0.4582],
        [ 0.2815, -1.1944, -0.7612,  0.6595,  1.4611]])

, тогда результирующий тензор должен быть похож на

tensor([[0.5756,  0.2463,  1.3940, 0.9690,  1.4913, -0.2129, -0.0348, -1.6920, -0.0157],
 [0.2463, 1.3940,  0.8473, 1.4913, -0.2129,  0.8331, -1.6920, -0.0157,  0.6159],
...
[-0.0157,  0.6159,  0.1038, 0.3861,  0.1293,  0.4582, -0.7612,  0.6595,  1.4611]])

1 Ответ

0 голосов
/ 14 марта 2020

Очень наивная реализация может быть

y = torch.randn(5, 5)
x = torch.zeros((9, 3, 3))
count = 0
for i in range(3) :
    for j in range(3) :
        x[count] = y[i : i + 3, j : j + 3]
        count += 1

Пример вывода:

y = tensor([[ 0.0361, -0.4931, -1.1977, -0.5224, -3.4067],
        [ 0.2380, -1.1042, -0.0696, -2.0487, -0.4123],
        [ 0.6567, -0.2485, -0.3954, -0.8197, -0.4903],
        [ 1.0073,  1.4759,  0.3532,  0.3565, -1.5257],
        [-0.8493, -0.0532,  1.0918,  1.2715, -0.1775]])

x = tensor([[[ 0.0361, -0.4931, -1.1977],
         [ 0.2380, -1.1042, -0.0696],
         [ 0.6567, -0.2485, -0.3954]],

        [[-0.4931, -1.1977, -0.5224],
         [-1.1042, -0.0696, -2.0487],
         [-0.2485, -0.3954, -0.8197]],

        [[-1.1977, -0.5224, -3.4067],
         [-0.0696, -2.0487, -0.4123],
         [-0.3954, -0.8197, -0.4903]],

        [[ 0.2380, -1.1042, -0.0696],
         [ 0.6567, -0.2485, -0.3954],
         [ 1.0073,  1.4759,  0.3532]],

        [[-1.1042, -0.0696, -2.0487],
         [-0.2485, -0.3954, -0.8197],
         [ 1.4759,  0.3532,  0.3565]],

        [[-0.0696, -2.0487, -0.4123],
         [-0.3954, -0.8197, -0.4903],
         [ 0.3532,  0.3565, -1.5257]],

        [[ 0.6567, -0.2485, -0.3954],
         [ 1.0073,  1.4759,  0.3532],
         [-0.8493, -0.0532,  1.0918]],

        [[-0.2485, -0.3954, -0.8197],
         [ 1.4759,  0.3532,  0.3565],
         [-0.0532,  1.0918,  1.2715]],

        [[-0.3954, -0.8197, -0.4903],
         [ 0.3532,  0.3565, -1.5257],
         [ 1.0918,  1.2715, -0.1775]]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...