Тензор структуры на основе индекса - PullRequest
0 голосов
/ 14 мая 2019

Я пытаюсь объединить несколько тензоров на основе соответствующего тензора индекса. В качестве примера игрушки у меня есть тензор A = [[0,1,2],[2,3,4]] с соответствующим индексным вектором B = [0,1] и другой тензор C = [[5,6,7],[8,9,10],[11,12,13]] с соответствующим индексным вектором D = [0,1,2]. Я стремлюсь вернуть тензор такой, чтобы E = [[0,1,2],[5,6,7],[2,3,4],[8,9,10],[11,12,13]], где значения с одинаковыми индексами сгруппированы вместе. Я пытаюсь эффективно сделать это, поскольку оно содержится в моей модели NN и может страдать от проблем с производительностью во время обратного распространения.

Я смог добиться этого, создав словарь и выполнив цикл как таковой:

E = defaultdict(list)
for vector, indices in zip([A,C],[B,D]):
    for vector, i in enumerate(vector):
        E[indices[i]].append(vector)
E = list(E.values())

Однако такой подход значительно замедляет обучение модели. Есть ли способ сделать это более эффективно? Кажется, разные операции Pytorch над index_ не решают эту проблему. Спасибо!

1 Ответ

0 голосов
/ 14 мая 2019

Я предполагаю, что тензоры A и C имеют разное количество строк, но одинаковое количество столбцов.

A = torch.LongTensor([[0,1,2],[2,3,4]]) # torch.Size([2, 3])
C = torch.LongTensor([[5,6,7],[8,9,10],[11,12,13]]) # torch.Size([3, 3])

assert A.dim() == C.dim() == 2
assert A.size(1) == C.size(1)

shape_E = (A.size(0) + C.size(0), A.size(1))
E = torch.zeros(*shape_E).long()

# Two possible cases
if A.size(0) > C.size(0): # C has fewer rows
    n_rows = C.size(0)
    E[np.arange(0, 2*n_rows, 2), :] = A[:n_rows, :]
    E[np.arange(1, 2*n_rows, 2), :] = C[:n_rows, :]
    E[2*n_rows:, :] = A[n_rows:, :]
else: # A has fewer rows
    n_rows = A.size(0)
    E[np.arange(0, 2*n_rows, 2), :] = A[:n_rows, :]
    E[np.arange(1, 2*n_rows, 2), :] = C[:n_rows, :]
    E[2*n_rows:, :] = C[n_rows:, :]

print(E)

Вывод

tensor([[ 0,  1,  2],
        [ 5,  6,  7],
        [ 2,  3,  4],
        [ 8,  9, 10],
        [11, 12, 13]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...