Нарезка неравномерных столбцов из тензорного массива - PullRequest
0 голосов
/ 08 мая 2019

У меня есть такой массив:

([[[ 0,  1,  2],
 [ 3,  4,  5]],

[[ 6,  7,  8],
[ 9, 10, 11]],

[[12, 13, 14],
[15, 16, 17]]])

Если я хочу нарезать числа от 12 до 17, я бы использовал:

arr[2, 0:2, 0:3]

но как мне нарезать массив, чтобы получить от 12 до 16?

Ответы [ 2 ]

2 голосов
/ 08 мая 2019

Сначала вам нужно «сплющить» последние два измерения.Только тогда вы сможете извлечь нужные элементы:

xf = x.view(x.size(0), -1)  # flatten the last dimensions
xf[2, 0:5]
Out[87]: tensor([12, 13, 14, 15, 16])
0 голосов
/ 08 мая 2019

Другим способом было бы просто индексировать в тензор и нарезать то, что необходимо, как в:

# input tensor 
t = tensor([[[ 0,  1,  2],
             [ 3,  4,  5]],

           [[ 6,  7,  8],
            [ 9, 10, 11]],

           [[12, 13, 14],
            [15, 16, 17]]])

# slice the last `block`, then flatten it and 
# finally slice all elements but the last one
In [10]: t[-1].view(-1)[:-1]   
Out[10]: tensor([12, 13, 14, 15, 16])

Обратите внимание, что, поскольку это базовая нарезка, она возвращает вид .Таким образом, внесение любых изменений в нарезанную часть также повлияет на исходный тензор.Например:

# assign it to some variable name
In [11]: sliced = t[-1].view(-1)[:-1] 
In [12]: sliced      
Out[12]: tensor([12, 13, 14, 15, 16])

# modify one element
In [13]: sliced[-1] = 23   

In [14]: sliced  
Out[14]: tensor([12, 13, 14, 15, 23])

# now, the original tensor is also updated
In [15]: t  
Out[15]: 
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]],

        [[12, 13, 14],
         [15, 23, 17]]])
...