Быстро найти индексы, значения которых превышают пороговое значение в Numpy / PyTorch - PullRequest
0 голосов
/ 26 апреля 2018

Задача

Учитывая матрицу numpy или pytorch, найдите индексы ячеек, значения которых превышают заданный порог.

Моя реализация

#abs_cosine is the matrix
#sim_vec is the wanted

sim_vec = []
for m in range(abs_cosine.shape[0]):
    for n in range(abs_cosine.shape[1]):
        # exclude diagonal cells
        if m != n and abs_cosine[m][n] >= threshold:
            sim_vec.append((m, n))

Беспокойство

Скорость . Все остальные вычисления построены на Pytorch, использование numpy уже является компромиссом, потому что оно перенесло вычисления с GPU на CPU. Чистые циклы python for сделают весь процесс еще хуже (для небольших наборов данных уже в 5 раз медленнее). Мне было интересно, можем ли мы переместить все вычисления в Numpy (или pytorch), не вызывая циклов for?

Улучшение, которое я могу придумать (но застряло ...)

bool_cosine = abs_cosine> порог

, который возвращает логическую матрицу True и False. Но я не могу найти способ быстро получить индексы ячеек True.

Ответы [ 2 ]

0 голосов
/ 26 апреля 2018

Следующее для PyTorch (полностью на GPU)

# abs_cosine should be a Tensor of shape (m, m)
mask = torch.ones(abs_cosine.size()[0])
mask = 1 - mask.diag()
sim_vec = torch.nonzero((abs_cosine >= threshold)*mask)

# sim_vec is a tensor of shape (?, 2) where the first column is the row index and second is the column index

следующие работы в numpy

mask = 1 - np.diag(np.ones(abs_cosine.shape[0]))
sim_vec = np.nonzero((abs_cosine >= 0.2)*mask)
# sim_vec is a 2-array tuple where the first array is the row index and the second array is column index
0 голосов
/ 26 апреля 2018

Это примерно в два раза быстрее, чем np.where

import numba as nb
@nb.njit(fastmath=True)

def get_threshold(abs_cosine,threshold):
  idx=0
  sim_vec=np.empty((abs_cosine.shape[0]*abs_cosine.shape[1],2),dtype=np.uint32)
  for m in range(abs_cosine.shape[0]):
    for n in range(abs_cosine.shape[1]):
      # exclude diagonal cells
      if m != n and abs_cosine[m,n] >= threshold:
        sim_vec[idx,0]=m
        sim_vec[idx,1]=n
        idx+=1

  return sim_vec[0:idx,:]

Первый вызов длится примерно 0,2 с (накладные расходы на компиляцию). Если массив находится на графическом процессоре, может также быть способ выполнить все вычисления на графическом процессоре.

Тем не менее, я не очень доволен производительностью, поскольку простая логическая операция примерно в 5 раз быстрее, чем решение, показанное выше, и в 10 раз быстрее, чем np.where. Если порядок индексов не имеет значения, эту проблему можно распараллелить.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...