Я не думаю, что это напрямую возможно реализовать с использованием существующего InstanceNorm1d
, проще всего было бы реализовать его самостоятельно с нуля. Я сделал быструю реализацию, которая должна работать. Чтобы сделать его немного более общим, этому модулю требуется логическая маска (булевский тензор того же размера, что и входные данные), который указывает, какие элементы следует учитывать при прохождении через экземплярную норму.
import torch
class MaskedInstanceNorm1d(torch.nn.Module):
def __init__(self, num_features, eps=1e-6, momentum=0.1, affine=True, track_running_stats=False):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
self.gamma = None
self.beta = None
if self.affine:
self.gamma = torch.nn.Parameter(torch.ones((1, self.num_features, 1), requires_grad=True))
self.beta = torch.nn.Parameter(torch.zeros((1, self.num_features, 1), requires_grad=True))
self.running_mean = None
self.running_variance = None
if self.affine:
self.running_mean = torch.zeros((1, self.num_features, 1), requires_grad=True)
self.running_variance = torch.zeros((1, self.num_features, 1), requires_grad=True)
def forward(self, x, mask):
mean = torch.zeros((1, self.num_features, 1), requires_grad=False)
variance = torch.ones((1, self.num_features, 1), requires_grad=False)
# compute masked mean and variance of batch
for c in range(self.num_features):
if mask[:, c, :].any():
mean[0, c, 0] = x[:, c, :][mask[:, c, :]].mean()
variance[0, c, 0] = (x[:, c, :][mask[:, c, :]] - mean[0, c, 0]).pow(2).mean()
# update running mean and variance
if self.training and self.track_running_stats:
for c in range(self.num_features):
if mask[:, c, :].any():
self.running_mean[0, c, 0] = (1-self.momentum) * self.running_mean[0, c, 0] \
+ self.momentum * mean[0, c, 0]
self.running_variance[0, c, 0] = (1-self.momentum) * self.running_variance[0, c, 0] \
+ self.momentum * variance[0, c, 0]
# compute output
x = (x - mean)/(self.eps + variance).sqrt()
if self.affine:
x = x * self.gamma + self.beta
return x