Maybe this works, but is there a better solution?
import torch
import torch.nn.functional as F

def mask_softmax(vec, mask):
    leafs = vec.shape[0]
    rows = vec.shape[1]
    cols = vec.shape[2]
    for k in range(leafs):
        stop = int(mask[k])
        for j in range(stop, cols):
            vec[k, :, j] = torch.zeros(rows)  # all rows of col j <- 0
    # Replace the zeroed entries with -inf so they vanish in the softmax.
    # This relies on the valid entries being strictly positive
    # (true here, since torch.rand samples from (0, 1)).
    vec = vec - torch.where(vec > 0,
                            torch.zeros_like(vec),
                            torch.ones_like(vec) * float('inf'))
    # softmax of a row that is entirely -inf yields nan
    for k in range(leafs):
        for i in range(rows):
            vec[k, i] = F.softmax(vec[k, i], dim=0)
    vec[vec != vec] = 0  # nan -> 0
    return vec
# testing
a = torch.rand((2, 2, 4))
mask = torch.tensor([1., 3.])
mask_softmax(a, mask)
>>> tensor([[[0.5027, 0.4973, 0.0000, 0.0000],
             [0.6494, 0.3506, 0.0000, 0.0000]],

            [[0.3412, 0.3614, 0.2975, 0.0000],
             [0.2699, 0.3978, 0.3323, 0.0000]]])
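
For comparison, the Python loops can be replaced by one broadcasted comparison and a single batched softmax call. Below is a minimal sketch (the name `mask_softmax_vectorized` is mine); like the version above, it assumes every `mask[k]` is at least 1, so no row ends up fully masked:

    import torch
    import torch.nn.functional as F

    def mask_softmax_vectorized(vec, mask):
        # Broadcast a column-index grid against the per-leaf cutoffs:
        # positions with index >= mask[k] are invalid for leaf k.
        cols = vec.shape[2]
        invalid = torch.arange(cols) >= mask.long().unsqueeze(1)  # (leafs, cols)
        # Fill invalid positions with -inf; exp(-inf) = 0, so they
        # contribute nothing to the softmax normalization.
        vec = vec.masked_fill(invalid.unsqueeze(1), float('-inf'))
        return F.softmax(vec, dim=2)

    mask_softmax_vectorized(a, mask)  # should match the output above

Unlike the loop version, this does not depend on the valid entries being positive and does not mutate its argument in place, since `masked_fill` returns a new tensor.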