Я не эксперт в этой области, но потратил некоторое время на эксперименты.
Оказывается, работать с одномерным (1D) массивом намного быстрее (второй метод).
import time
import torch
class Timer:
    """Context manager that prints how long its ``with`` body took to run.

    On exit it prints ``time used: <seconds>s`` with two decimals. Exceptions
    raised inside the body are not suppressed; the elapsed time is still
    printed before they propagate.

    Attributes:
        time: Clock reading taken when the ``with`` block was entered. It is a
            ``time.perf_counter()`` value, so only differences between two
            readings are meaningful.
    """

    def __enter__(self):
        """Record the start reading and return the timer itself."""
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be skewed by system clock adjustments.
        self.time = time.perf_counter()
        # Returning self makes ``with Timer() as t:`` bind the timer, not None.
        return self

    def __exit__(self, *exc):
        """Print the elapsed time; returns None so exceptions propagate."""
        print(f'time used: {time.perf_counter() - self.time:.2f}s')
# a = torch.rand([1,500,1000])
# Problem size: a single-channel HEIGHT x WIDTH mask, and the number of random
# valid points drawn for each half of a coordinate pair.
HEIGHT, WIDTH = 500, 1000
NUM_SAMPLES = 200000

m = torch.randint(2, [1, HEIGHT, WIDTH])  # random 0/1 mask tensor
# Number of valid (== 1) points; a sum avoids building the full index table
# just to read its length.
valid_len = int((m == 1).sum())
# Random positions in the list of valid points. Both methods use the same
# draws so their results can be compared element by element.
rand_one = torch.randint(valid_len, [NUM_SAMPLES])
rand_two = torch.randint(valid_len, [NUM_SAMPLES])

# method one: gather (row, col) pairs from the table of 2D coordinates
m0 = (m == 1).nonzero()  # valid points, shape [valid_len, 3]
m0 = m0[:, 1:]  # drop the leading-dim index column -> shape [valid_len, 2]
with Timer():
    one0 = torch.index_select(m0, 0, rand_one)  # NUM_SAMPLES valid points
    two0 = torch.index_select(m0, 0, rand_two)  # NUM_SAMPLES valid points again
    coor0 = torch.cat([one0, two0], dim=1)  # (row, col, row, col) per sample
# >>> time used: 1.05s  (timing reported by the original author)

# method two: gather flat 1D indices, then convert them to coordinates
m1 = m.reshape(-1) == 1  # flattened mask, shape [HEIGHT * WIDTH]
m1 = m1.nonzero().reshape(-1)  # flat indices of valid points, shape [valid_len]
with Timer():
    one1 = m1.take(rand_one)  # NUM_SAMPLES valid flat indices
    two1 = m1.take(rand_two)  # again
    # A flat index is row * WIDTH + col, so floor-division and modulo by WIDTH
    # recover the (row, col) coordinates.
    coor1 = torch.stack(
        [one1 // WIDTH, one1 % WIDTH, two1 // WIDTH, two1 % WIDTH],
        dim=1,
    )
# >>> time used: 0.07s  (timing reported by the original author)

# Both methods must yield exactly the same coordinates; torch.equal checks
# shape as well as every element, with no hard-coded element count.
assert torch.equal(coor0, coor1), "method one and method two disagree"
ура