Я пытаюсь квантовать модель, в которой используется PReLU
. Замена PReLU
на ReLU
невозможна, так как это резко влияет на производительность сети до такой степени, что становится бесполезным.
Насколько мне известно, PReLU
не поддерживается в Pytorch, когда дело доходит до квантования. Поэтому я попытался переписать этот модуль вручную и реализовать умножение и сложение, используя torch.FloatFunctional()
, чтобы обойти это ограничение.
Это то, что я придумал до сих пор:
class PReLU_Quantized(nn.Module):
def __init__(self, prelu_object):
super().__init__()
self.weight = prelu_object.weight
self.quantized_op = nn.quantized.FloatFunctional()
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, inputs):
# inputs = torch.max(0, inputs) + self.weight * torch.min(0, inputs)
self.weight = self.quant(self.weight)
weight_min_res = self.quantized_op.mul(self.weight, torch.min(inputs)[0])
inputs = self.quantized_op.add(torch.max(inputs)[0], weight_min_res).unsqueeze(0)
self.weight = self.dequant(self.weight)
return inputs
и для замены:
class model(nn.Module):
def __init__(self)
super().__init__()
....
self.prelu = PReLU()
self.prelu_q = PReLU_Quantized(self.prelu)
....
В основном, я читаю заученный параметр существующего модуля prelu и сам запускаю расчет в новом модуле. Кажется, что модуль работает в том смысле, что он не дает сбой всего приложения.
Однако, чтобы оценить, действительно ли моя реализация верна и дает ли тот же результат, что и исходный модуль, я попытался протестировать ее. . Вот аналог для обычных моделей (то есть не квантованной модели): По какой-то причине ошибка между фактическим PReLU
и моей реализацией очень велика!
Вот примеры различий в разных слоях:
diff : 1.1562038660049438
diff : 0.02868632599711418
diff : 0.3653906583786011
diff : 1.6100226640701294
diff : 0.8999372720718384
diff : 0.03773299604654312
diff : -0.5090572834014893
diff : 0.1654307246208191
diff : 1.161868691444397
diff : 0.026089997962117195
diff : 0.4205571115016937
diff : 1.5337920188903809
diff : 0.8799554705619812
diff : 0.03827812895178795
diff : -0.40296515822410583
diff : 0.15618863701820374
и разница вычисляется следующим образом в прямой проход:
def forward(self, x):
residual = x
out = self.bn0(x)
out = self.conv1(out)
out = self.bn1(out)
out = self.prelu(out)
out2 = self.prelu2(out)
print(f'diff : {( out - out2).mean().item()}')
out = self.conv2(out)
...
Это обычная реализация, которую я использовал на обычной модели (т.е. не квантованной!), чтобы оценить, дает ли она правильный результат, а затем перейти к квантованной версии:
class PReLU_2(nn.Module):
def __init__(self, prelu_object):
super().__init__()
self.prelu_weight = prelu_object.weight
self.weight = self.prelu_weight
def forward(self, inputs):
x = self.weight
tmin, _ = torch.min(inputs,dim=0)
tmax, _ = torch.max(inputs,dim=0)
weight_min_res = torch.mul(x, tmin)
inputs = torch.add(tmax, weight_min_res)
inputs = inputs.unsqueeze(0)
return inputs
Что мне здесь не хватает?