Как установить маску в отдельном слое внимания в pytorch - PullRequest
0 голосов
/ 08 ноября 2019

Сейчас я учусь методу Внимание в pytorch,

Я не могу решить сделать тензор маски в отдельном слое внимания из-за ошибки измерения. Как получить маску тензор_сайза перед уровнем внимания ??

, например, при вводе размера тензора: [1,25,10], правильно ли это подготовить размер тензора: [25,25] для размера внимания_карты?

Есть ли способ заранее сделать маску слоя внимания ??

Я имею в виду коды def и class в https://github.com/YutaroOgawa/pytorch_advanced/blob/master/7_nlp_sentiment_transformer/7-6_Transformer.ipynb

Извините за неудобства, Можете ли вы датьмне совет?

# single attention code is below

class Attention(nn.Module):

    def __init__(self, d_model=300):
        super().__init__()

        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)

        # output
        self.out = nn.Linear(d_model, d_model)

        self.d_k = d_model

    def forward(self, q, k, v, mask):

        k = self.k_linear(k)
        q = self.q_linear(q)
        v = self.v_linear(v)

        weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.d_k)

        mask = mask.unsqueeze(1)
        weights = weights.masked_fill(mask == 0, -1e9)

        # softmax
        normlized_weights = F.softmax(weights, dim=-1)

        output = torch.matmul(normlized_weights, v)

        output = self.out(output)

        return output, normlized_weights

#example
# test Attention Output

test_array=np.arange(50)
test_inputs=torch.Tensor(test_array).view(1,5,10)
print(test_inputs)
#torch.Size([5, 10])

# test mask 
inputs_mask= (test_inputs<5)

print(inputs_mask.size())
#mask:torch.Size([1, 5, 10])

atten=Attention(d_model=10)
output_attention,_=atten(q=test_inputs,k=test_inputs,v=test_inputs,mask=inputs_mask)

#error: The size of tensor a (10) must match the size of tensor b (5) at non-singleton dimension 3**strong text**

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...