Pytorch: умножьте два тензора высокой размерности (2, 5, 3) * (2, 5) на (2, 5, 3) - PullRequest
0 голосов
/ 13 июля 2020

Я хочу умножить два тензора высокой размерности (2, 5, 3) * (2, 5) на (2, 5, 3), которые умножают каждый вектор-строку на скаляр.

Например,

emb = nn.Embedding(6, 3)

input = torch.tensor([[1, 2, 3, 4, 5,],
                      [2, 3, 1, 4, 5,]])
input_emb = emb(input)


print(input.shape)
> torch.Size([2, 5])

print(input_emb.shape)
> torch.Size([2, 5, 3])

print(input_emb)
> tensor([[[-1.9114, -0.1580,  1.2186],
         [ 0.4627,  0.9119, -1.1691],
         [ 0.6452, -0.6944,  1.9659],
         [-0.5048,  0.6411, -1.3568],
         [-0.2328, -0.9498,  0.7216]],

        [[ 0.4627,  0.9119, -1.1691],
         [ 0.6452, -0.6944,  1.9659],
         [-1.9114, -0.1580,  1.2186],
         [-0.5048,  0.6411, -1.3568],
         [-0.2328, -0.9498,  0.7216]]], grad_fn=<EmbeddingBackward>)

Я хочу умножить, может следующим образом:

// It is written in this way for convenience, not mathematical true. 

// multiply each row vector by a scalar
[[
         [-1.9114, -0.1580,  1.2186] * 1
         [ 0.4627,  0.9119, -1.1691] * 2
         [ 0.6452, -0.6944,  1.9659] * 3
         [-0.5048,  0.6411, -1.3568] * 4
         [-0.2328, -0.9498,  0.7216] * 5
] 
[
         [ 0.4627,  0.9119, -1.1691] * 2
         [ 0.6452, -0.6944,  1.9659] * 3
         [-1.9114, -0.1580,  1.2186] * 1
         [-0.5048,  0.6411, -1.3568] * 4
         [-0.2328, -0.9498,  0.7216] * 5
]]

За исключением способов нескольких циклов, как кратко реализовать это с помощью PyTorch API?
Заранее благодарим.

1 Ответ

1 голос
/ 13 июля 2020

Можно, правильно совместив размеры обоих тензоров:

import torch
from torch.nn import Embedding

emb = Embedding(6, 3)
inp = torch.tensor([[1, 2, 3, 4, 5,],
                      [2, 3, 1, 4, 5,]])
input_emb = emb(inp)

inp[...,None] * input_emb

tensor([[[-0.3069, -0.7727, -0.3772],
         [-2.8308,  1.3438, -1.1167],
         [ 0.6366,  0.6509, -3.2282],
         [-4.3004,  3.2342, -0.6556],
         [-3.0045, -0.0191, -7.4436]],

        [[-2.8308,  1.3438, -1.1167],
         [ 0.6366,  0.6509, -3.2282],
         [-0.3069, -0.7727, -0.3772],
         [-4.3004,  3.2342, -0.6556],
         [-3.0045, -0.0191, -7.4436]]], grad_fn=<MulBackward0>)
...