Как изменить одно значение тензора на ноль в соответствии с другим значением тензора в pytorch? - PullRequest
3 голосов
/ 25 мая 2020

У меня есть два тензора: тензор a и тензор b. Как я могу изменить какое-либо значение тензора a в соответствии со значением тензора b?

Я знаю, что следующие коды верны, но при большом тензоре он работает довольно медленно. Есть ли другой способ?

import torch
a = torch.rand(10).cuda()
b = torch.rand(10).cuda()
a[b > 0.5] = 0.

Ответы [ 2 ]

1 голос
/ 25 мая 2020

Для этого конкретного варианта использования также рассмотрите

a * (b <= 0.5)

, который кажется самым быстрым из следующих

In [1]: import torch
   ...: a = torch.rand(3**10)
   ...: b = torch.rand(3**10)

In [2]: %timeit a[b > 0.5] = 0.
553 µs ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [3]: a = torch.rand(3**10)

In [4]: %timeit temp = torch.where(b > 0.5, torch.tensor(0.), a)
   ...:
49 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [5]: a = torch.rand(3**10)

In [6]: %timeit temp = (a * (b <= 0.5))
44 µs ± 381 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [7]: %timeit a.masked_fill_(b > 0.5, 0.)
244 µs ± 3.48 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1 голос
/ 25 мая 2020

Думаю, torch.where будет быстрее. У меня есть измерения в ЦП, вот результат.

import torch
a = torch.rand(3**10)
b = torch.rand(3**10)
%timeit a[b > 0.5] = 0.
852 µs ± 30.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit temp = torch.where(b > 0.5, torch.tensor(0.), a)
294 µs ± 4.51 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
...