pytorch masked_fill: почему я не могу замаскировать все нули? - PullRequest
0 голосов
/ 27 мая 2019

Я хочу замаскировать все нули в матрице оценок с помощью -np.inf, но я могу замаскировать только часть нулей, выглядело как

enter image description here

Вы видите в правом верхнем углу все еще нули, которые не были замаскированы с -np.inf

Вот мои коды:

q = torch.Tensor([np.random.random(10),np.random.random(10),np.random.random(10), np.random.random(10), np.zeros((10,1)), np.zeros((10,1))])
k = torch.Tensor([np.random.random(10),np.random.random(10),np.random.random(10), np.random.random(10), np.zeros((10,1)), np.zeros((10,1))])
scores = torch.matmul(q, k.transpose(0,1)) / math.sqrt(10)
mask = torch.Tensor([1,1,1,1,0,0])
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask==0, -np.inf)

Может, маска неправильная?

1 Ответ

2 голосов
/ 27 мая 2019

Ваша маска неверна. Попробуйте

scores = scores.masked_fill(scores == 0, -np.inf)

scores теперь выглядит как

tensor([[1.4796, 1.2361, 1.2137, 0.9487,   -inf,   -inf],
        [0.6889, 0.4428, 0.6302, 0.4388,   -inf,   -inf],
        [0.8842, 0.7614, 0.8311, 0.6431,   -inf,   -inf],
        [0.9884, 0.8430, 0.7982, 0.7323,   -inf,   -inf],
        [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
        [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...