Как сделать встраивание флага одной строки в PyTorch (не nn.Embedding)? - PullRequest
0 голосов
/ 07 июля 2019

С помощью генератора я создаю случайную партию, как:

import torch

n = 10
batch_size = 2

x = torch.zeros((batch_size, n), dtype=torch.float)
in_flags = torch.randint(n, (batch_size,), dtype=torch.long)

for idx, row in enumerate(x):
    row[in_flags[idx]] = 1.0

Но недостатком этого является то, что цикл выполняется в Python.Это первоначальный смысл встраивания (не путайте это с PyTorch nn.embedding).Можно ли сделать один оператор PyTorch, чтобы он выполнялся как собственный, так и в графическом процессоре?

1 Ответ

0 голосов
/ 08 июля 2019

Вы можете сделать так:

import torch

n = 10
batch_size = 2

in_flags = torch.randint(n, (batch_size,), dtype=torch.long)
x = torch.zeros((batch_size, n), dtype=torch.float)

# this is how you can do this
x[torch.arange(batch_size), in_flags] = 1.0

print(in_flags)
print(x)

Выход:

tensor([8, 0])
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
...