Индексация 3d-тензора с использованием 2d-тензора - PullRequest
2 голосов
/ 11 апреля 2019

У меня есть 3-й тензор source формы (bsz x slen1 x nhd) и 2-й тензор index формы (bsz x slen2). Точнее, у меня есть:

source = 32 x 20 x 768
index  = 32 x 16

Каждое значение в тензоре index находится в интервале [0, 19], который является индексом нужного вектора в соответствии со вторым значением тензора source.

После индексации я ожидаю выходной тензор формы, 32 x 16 x 768.

В настоящее время я делаю это:

bsz, _, nhid = source.size()
_, slen = index.size()

source = source.reshape(-1, nhid)
source = source[index.reshape(-1), :]
source = source.reshape(bsz, slen, nhid)

Итак, я преобразовываю тензор 3d-источника в 2d-тензор и 2d-индексирующий тензор в 1d-тензор, а затем выполняю индексацию. Это правильно?

Есть ли лучший способ сделать это?

Update

Я проверил, что мой код не дает ожидаемого результата. Чтобы объяснить, чего я хочу, я предоставляю следующий фрагмент кода.

source = torch.FloatTensor([
    [[ 0.2413, -0.6667,  0.2621],
     [-0.4216,  0.3722, -1.2258],
     [-0.2436, -1.5746, -0.1270],
     [ 1.6962, -1.3637,  0.8820],
     [ 0.3490, -0.0198,  0.7928]],

    [[-0.0973,  2.3106, -1.8358],
     [-1.9674,  0.5381,  0.2406],
     [ 3.0731,  0.3826, -0.7279],
     [-0.6262,  0.3478, -0.5112],
     [-0.4147, -1.8988, -0.0092]]
     ])

index = torch.LongTensor([[0, 1, 2, 3], 
                          [1, 2, 3, 4]])

И я хочу, чтобы тензор вывода был:

torch.FloatTensor([
    [[ 0.2413, -0.6667,  0.2621],
     [-0.4216,  0.3722, -1.2258],
     [-0.2436, -1.5746, -0.1270],
     [ 1.6962, -1.3637,  0.8820]],

    [[-1.9674,  0.5381,  0.2406],
     [ 3.0731,  0.3826, -0.7279],
     [-0.6262,  0.3478, -0.5112],
     [-0.4147, -1.8988, -0.0092]]
     ])

Ответы [ 2 ]

2 голосов
/ 12 апреля 2019

Я решил проблему. Итак, мне нужно было определить смещение. Следующий код работает для меня.

index = torch.LongTensor([[0, 1, 2, 3], [1, 2, 3, 4]])
offset = torch.arange(0, source.size(0) * source.size(1), source.size(1))
index = index + offset.unsqueeze(1)

source = source.reshape(-1, source.shape[-1])[index]
1 голос
/ 12 апреля 2019

Обновление :

source[torch.arange(source.shape[0]).unsqueeze(-1), index]

Обратите внимание, что torch.arange(source.shape[0]).unsqueeze(-1) дает:

tensor([[0],
        [1]])  # 2 x 1

и index:

tensor([[0, 1, 2, 3],
        [1, 2, 3, 4]])  # 2 x 4

arange индексирует размерность пакета, а index одновременно индексирует размерность slen1.Вызов unsqueeze добавляет дополнительное измерение x 1 к результату arange, так что они могут передаваться вместе.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...