Найти заданный c индекс элемента из Tensor (Matrix) - PullRequest
1 голос
/ 07 февраля 2020

Я студент, только начинающий изучать глубокое обучение.

   x_norm = (x**2).sum(1).view(-1, 1)
   if y is not None:
       y_norm = (y**2).sum(1).view(1, -1)
   else:
       y = x
       y_norm = x_norm.view(1, -1)
   ## NOTICE ##
   dist = torch.exp(-1*(x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))))
   return dist

dist = pairwise_distances(atom_s[:3,-3:])
zero_mat=torch.zeros_like(dist,dtype=torch.float)
dist= torch.where(dist>exp(-8),dist,zero_mat)

Выше мое кодирование для создания попарной карты расстояний. И измените некоторый элемент, который удовлетворяет условию, на 0.

Вопрос заключается в том, «как я могу получить индексы элементов, которые задают условие c (например, больше> 0,5) ?? без использования медленного« для » л oop.

1 Ответ

0 голосов
/ 07 февраля 2020

Если это нормальный тензор, вы можете использовать torch.nonzero

>>> (dist > 0.5).nonzero()

, который будет возвращать индексы всех элементов, которые больше 0.5

Пример:

>>> dist = torch.rand((6,5))
>>> dist
tensor([[0.7549, 0.0962, 0.3198, 0.6868, 0.8117],
        [0.0785, 0.7666, 0.2623, 0.5140, 0.2713],
        [0.5768, 0.8160, 0.8654, 0.6978, 0.0138],
        [0.8147, 0.1394, 0.3204, 0.0104, 0.2872],
        [0.1396, 0.5639, 0.7085, 0.7151, 0.8253],
        [0.6115, 0.0214, 0.6033, 0.1403, 0.1977]])

>>> (dist > 0.5).nonzero()
tensor([[0, 0],
        [0, 3],
        [0, 4],
        [1, 1],
        [1, 3],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 3],
        [3, 0],
        [4, 1],
        [4, 2],
        [4, 3],
        [4, 4],
        [5, 0],
        [5, 2]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...