Pytorch: Как найти индексы первого ненулевого элемента в каждой строке двумерного тензора? - PullRequest
1 голос
/ 11 мая 2019

У меня есть 2D-тензор с некоторым ненулевым элементом в каждой строке, например:

import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)

Я хочу, чтобы тензор содержал индекс первого ненулевого элемента в каждой строке:

indices = tensor([2],
                 [3])

Как я могу рассчитать это в Pytorch?

1 Ответ

0 голосов
/ 12 мая 2019

Я мог бы найти хитрый ответ на мой вопрос:

  tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                     [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
  idx = reversed(torch.Tensor(range(1,8)))
  print(idx)

  tmp2= torch.einsum("ab,b->ab", (tmp, idx))

  print(tmp2)

  indices = torch.argmax(tmp2, 1, keepdim=True)
  print(indeces)

Результат:

tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
       [0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
        [3]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...