Вы можете сделать что-то вроде этого, где:
m
- ваша маска; x
- ваша спектрограмма; o
ваш вывод;
import torch
torch.manual_seed(2020)
m = torch.tensor([[0, 1, 0]]).to(torch.int32)
x = torch.rand((1, 3, 2))
o = torch.rand((1, 3, 2))
print(o)
# tensor([[[0.5899, 0.8105],
# [0.2512, 0.6307],
# [0.5403, 0.8033]]])
print(x)
# tensor([[[0.4869, 0.1052],
# [0.5883, 0.1161],
# [0.4949, 0.2824]]])
o[:, m[0].to(torch.bool), :] = x[:, m[0].to(torch.bool), :]
# or
# o[:, m[0] == 1, :] = x[:, m[0] == 1, :]
print(o)
# tensor([[[0.5899, 0.8105],
# [0.5883, 0.1161],
# [0.5403, 0.8033]]])