Как я могу нарезать тензор PyTorch с другим тензором? - PullRequest
3 голосов
/ 18 апреля 2020

У меня есть:

inp =  torch.randn(4, 1040, 161)

, и у меня есть другой тензор с именем indices со значениями:

tensor([[124, 583, 158, 529],
        [172, 631, 206, 577]], device='cuda:0')

Я хочу эквивалент:

inp0 = inp[:,124:172,:]
inp1 = inp[:,583:631,:]
inp2 = inp[:,158:206,:]
inp3 = inp[:,529:577,:]

За исключением того, что все сложено вместе, иметь размер. [4, 48, 161]. Как я могу выполнить sh это?

В настоящее время мое решение: for l oop:

            left_indices = torch.empty(inp.size(0), self.side_length, inp.size(2))
            for batch_index in range(len(inp)):
                print(left_indices_start[batch_index].item())
                left_indices[batch_index] = inp[batch_index, left_indices_start[batch_index].item():left_indices_end[batch_index].item()]

Ответы [ 2 ]

1 голос
/ 18 апреля 2020
inp =  torch.randn(4, 1040, 161)   
indices = torch.tensor([[124, 583, 158, 529],
            [172, 631, 206, 577]])
k = zip(indices[0], indices[1])
for i,j in k:
    print(inp[:,i:j,:])

Вы можете реализовать это следующим образом ... Функция zip помогает преобразовать тензор ваших индексов в список кортежей, которые вы можете использовать напрямую для l oop

Надеюсь, это поможет вам. ...

1 голос
/ 18 апреля 2020

Здесь вы go (РЕДАКТИРОВАТЬ: вам, вероятно, нужно скопировать тензоры в процессор, используя tensor=tensor.cpu() перед выполнением следующих операций):

index = tensor([[124, 583, 158, 529],
    [172, 631, 206, 577]], device='cuda:0')
#create a concatenated list of ranges of indices you desire to slice
indexer = np.r_[tuple([np.s_[i:j] for (i,j) in zip(index[0,:],index[1,:])])]
#slice using numpy indexing
sliced_inp = inp[:, indexer, :]

Вот как это работает:

np.s_[i:j] создает объект среза (просто диапазон) индексов от начала = i до конца = j.

np.r_[i:j, k:m] создает список ВСЕХ индексов в срезах (i,j) и (k,m) (Вы можете передать большее количество срезов в np.r_, чтобы объединить их все вместе за один раз. Это пример объединения только двух срезов .)

Следовательно, indexer создает список ВСЕХ индексов путем объединения списка срезов (каждый срез является диапазоном индексов).

ОБНОВЛЕНИЕ: Если вам необходимо удалить интервальные наложения и сортировать интервалы:

indexer = np.unique(indexer)

, если вы хотите удалить интервальные перекрытия, но не сортировать и не сохранять исходный порядок (и первые появления перекрытий)

uni = np.unique(indexer, return_index=True)[1]
indexer = [indexer[index] for index in sorted(uni)]
...