Многомерная случайная отрисовка без замены на предварительные выборки в pytorch - PullRequest
1 голос
/ 19 июня 2020

У меня есть тензор (N, I) из N строк с индексами I между 0 и Z, например, N=5, I=3, Z=100:

foo = tensor([[83,  5, 85],
              [ 7, 60, 66],
              [89, 25, 63],
              [58, 67, 47],
              [12, 46, 40]], device='cuda:0')

Теперь я хочу эффективно добавить X случайные дополнительные новые индексы (т. Е. Еще не включенные в тензор!) Между 0 и Z к тензору, например:

foo_new = tensor([[83,  5, 85,  9, 43, 53, 42],
                  [ 7, 60, 66, 85, 64, 22,  1],
                  [89, 25, 63, 38, 24,  4, 75],
                  [58, 67, 47, 83, 43, 29, 55],
                  [12, 46, 40, 74, 21, 11, 52]], device='cuda:0')

Тензор будет в конце имеют в каждой строке I+X уникальные индексы между 0 и Z, где I индексы - это индексы из начального тензора, а X индексы равномерно выбираются случайным образом без замены из остальных индексов {0...Z}\{I(n)}, где {I(n)} - индексы n-й строки.

Так что это похоже на многомерное случайное рисование без замены индексов 0 на Z, где первый I dr aws (в каждой строке) принудительно приводятся к индексам, заданным исходным тензором.

Как мне сделать это эффективно, особенно с потенциально большими Z?

Что я пробовал до сих пор (что было довольно медленно):

device = torch.cuda.current_device()
notinfoo = torch.ones((N,I), device=device).byte()
N_row = torch.arange(N, device=device).unsqueeze(dim=-1)
notinfoo[N_row, foo] = 0
foo_new = torch.stack([torch.cat((f, torch.arange(Z, device=device)[nf][torch.randperm(Z-I, device=device)][:X])) for f,nf in zip(foo,notinfoo)])

1 Ответ

0 голосов
/ 19 июня 2020

Используйте сначала numpy numpy.random.choice, чтобы получить образцы с replace=False без замены выборки. а затем объедините оба, используя torch.cat

import numpy as np
foo_new = torch.tensor(np.random.choice(100 , (5,4), replace=False))   # Z = 100
foo_new = torch.cat((foo, foo_new), 1)

foo_new
tensor([[83,  5, 85, 56, 83, 16, 20],
        [ 7, 60, 66, 43, 31, 75, 67],
        [89, 25, 63, 96,  3, 13, 11],
        [58, 67, 47, 55, 92, 70, 35],
        [12, 46, 40, 79, 61, 58, 76]])

...