Ключевой особенностью здесь является передача значений тензора lengths
в виде индексов для x
. Здесь, в упрощенном примере, я поменял местами размеры контейнера, поэтому сначала идет размерность индекса:
container = torch.arange(0, 50 )
container = f.reshape((5, 10))
>>>tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])
indices = torch.arange( 2, 7, dtype=torch.long )
>>>tensor([2, 3, 4, 5, 6])
print( container[ range( len(indices) ), indices] )
>>>tensor([ 2, 13, 24, 35, 46])
Примечание: мы получили одну вещь из строки (range( len(indices) )
делает последовательные номера строк) с номером столбца, заданным индексами [row_number]