Пары точек отбора проб из сетки в Pytorch - PullRequest
0 голосов
/ 21 октября 2019

Мне нужны пары точек выборки из сетки в PyTorch.

У меня есть тензор размера (1 x 500 x 1000). У меня также есть тензор размера маски (1 x 500 x 1000), обозначающий, является ли точка действительной или нет. Я хочу взять 200k точек пары из этой сетки. Другими словами, я хочу получить координаты выбранных пар точек в виде тензора размера (200k x 4), обозначающего (x1, y1, x2, y2) для всех пар 200k точек. Все точки в парах должны быть действительными точками.

Это будет повторяться много раз, поэтому мне нужен эффективный способ выполнения этой процедуры. Что такое элегантный способ реализовать это в PyTorch?

1 Ответ

0 голосов
/ 31 октября 2019

Здесь не эксперт, но я потратил некоторое время на то, чтобы попробовать.
Оказывается, работать с массивом 1D намного быстрее (второй метод).

import time
import torch
class Timer():
    def __init__(self):
        pass
    def __enter__(self):
        self.time = time.time()
    def __exit__(self, *exc):
        print(f'time used: {time.time() - self.time:.2f}s')

# a = torch.rand([1,500,1000])
m = torch.randint(2, [1, 500, 1000]) # mask tensor
valid_len = (m==1).nonzero().size()[0] # number of valid points
rand_one = torch.randint(valid_len, [200000]) # sample 200k of random int
rand_two = torch.randint(valid_len, [200000]) # sample 200k of random int

# method one
m0 = m == 1 # mask of shape torch.Size([1, 500, 1000])
m0 = m0.nonzero() # valid points of shape torch.Size([valid_len, 3])
m0 = m0[:, 1:] # reshape to shape torch.Size([valid_len, 2])
with Timer():
    one0 = torch.index_select(m0, 0, rand_one) # take 200k valid points
    two0 = torch.index_select(m0, 0, rand_two) # take 200k valid points again
    coor0 = torch.cat([one0, two0], dim=1) # stack them up
# >>> time used: 1.05s

# method two
m1 = m.reshape(-1) # reshape mask to torch.Size([500000])
m1 = m1==1 # mask of shape torch.Size([500000])
m1 = m1.nonzero() # valid points of shape torch.Size([valid_len, 1])
m1 = m1.reshape(-1) # valid points of shape torch.Size([valid_len])
with Timer():
    one1 = m1.take(rand_one) # take 200k valid points
    two1 = m1.take(rand_two) # again
    # transform them to coordinates and stack them up
    coor1 = torch.stack([one1 // 1000, one1 % 1000, two1 // 1000, two1 % 1000], dim=1)
# >>> time used: 0.07s

assert torch.sum(coor0 == coor1) == 800000 # make sure consistent result 

ура

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...