Назначить тензор на несколько срезов - PullRequest
2 голосов
/ 29 сентября 2019

Пусть

a = tensor([[0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0]])
b = torch.tensor([1, 2])
c = tensor([[1, 2, 0, 0],
            [0, 1, 2, 0],
            [0, 0, 1, 2]])

Есть ли способ получить c, назначив b срезами a без петель?То есть a[indices] = b для какого-то indices или чего-то подобного?

Ответы [ 2 ]

3 голосов
/ 29 сентября 2019

Вы можете использовать метод scatter в pytorch.

a = torch.tensor([[0, 0, 0, 0],
                 [0, 0, 0, 0],
                 [0, 0, 0, 0]])

b = torch.tensor([1, 2])

index = torch.tensor([[0,1],[1,2],[2,3]])

a.scatter_(1, index, b.view(-1,2).repeat(3,1))
# tensor([[1, 2, 0, 0],
#         [0, 1, 2, 0],
#         [0, 0, 1, 2]])
2 голосов
/ 30 сентября 2019

Логика этой операции немного сомнительна в том смысле, что неясно, каковы параметры операции.Тем не менее, один из способов получения желаемого результата из входных данных только с помощью векторизованных операций:

  • определяет, сколько строк необходимо (3 для вашего примера)
  • create a a с таким количеством столбцов, что за b следует столько нулей, сколько число строк (2 + 3) и выбранное количество строк (3)
  • assign bк началу a для каждого
  • сгладьте массив, обрежьте num_rows нули и измените форму на целевую форму.

В NumPy это может быть реализовано следующим образом:

import numpy as np


b = np.array([1, 2])
c = np.array([[1, 2, 0, 0],
              [0, 1, 2, 0],
              [0, 0, 1, 2]])

num_rows = 3
a = np.zeros((num_rows, len(b) + num_rows), dtype=b.dtype)
a[:, :len(b)] = b
a = a.ravel()[:-num_rows].reshape((num_rows, len(b) + num_rows - 1))

print(a)
# [[1 2 0 0]
#  [0 1 2 0]
#  [0 0 1 2]]

print(np.all(a == c))
# True

РЕДАКТИРОВАТЬ

Тот же самый подход, реализованный в Torch:

import torch as to


b = to.tensor([1, 2])
c = to.tensor([[1, 2, 0, 0],
               [0, 1, 2, 0],
               [0, 0, 1, 2]])

num_rows = 3
a = to.zeros((num_rows, len(b) + num_rows), dtype=b.dtype)
a[:, :len(b)] = b
a = a.flatten()[:-num_rows].reshape((num_rows, len(b) + num_rows - 1))

print(a)
# tensor([[1, 2, 0, 0],
#         [0, 1, 2, 0],
#         [0, 0, 1, 2]])

print(to.all(a == c))
# tensor(1, dtype=torch.uint8)
...