Как я могу включить PReLU в квантованную модель? - PullRequest
1 голос
/ 14 июля 2020

Я пытаюсь квантовать модель, в которой используется PReLU. Замена PReLU на ReLU невозможна, так как это резко влияет на производительность сети до такой степени, что становится бесполезным.

Насколько мне известно, PReLU не поддерживается в Pytorch, когда дело доходит до квантования. Поэтому я попытался переписать этот модуль вручную и реализовать умножение и сложение, используя torch.FloatFunctional(), чтобы обойти это ограничение.

Это то, что я придумал до сих пор:

class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.weight = prelu_object.weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = torch.max(0, inputs) + self.weight * torch.min(0, inputs)    
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(self.weight, torch.min(inputs)[0])
        inputs = self.quantized_op.add(torch.max(inputs)[0], weight_min_res).unsqueeze(0)
        self.weight = self.dequant(self.weight)
        return inputs

и для замены:

class model(nn.Module):
     def __init__(self)
         super().__init__()
         .... 
        self.prelu = PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
         ....

В основном, я читаю заученный параметр существующего модуля prelu и сам запускаю расчет в новом модуле. Кажется, что модуль работает в том смысле, что он не дает сбой всего приложения.

Однако, чтобы оценить, действительно ли моя реализация верна и дает ли тот же результат, что и исходный модуль, я попытался протестировать ее. . Вот аналог для обычных моделей (то есть не квантованной модели): По какой-то причине ошибка между фактическим PReLU и моей реализацией очень велика!

Вот примеры различий в разных слоях:

diff : 1.1562038660049438
diff : 0.02868632599711418
diff : 0.3653906583786011
diff : 1.6100226640701294
diff : 0.8999372720718384
diff : 0.03773299604654312
diff : -0.5090572834014893
diff : 0.1654307246208191
diff : 1.161868691444397
diff : 0.026089997962117195
diff : 0.4205571115016937
diff : 1.5337920188903809
diff : 0.8799554705619812
diff : 0.03827812895178795
diff : -0.40296515822410583
diff : 0.15618863701820374

и разница вычисляется следующим образом в прямой проход:

def forward(self, x):
    residual = x
    out = self.bn0(x)
    out = self.conv1(out)
    out = self.bn1(out)

    out = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out - out2).mean().item()}')

    out = self.conv2(out)
...

Это обычная реализация, которую я использовал на обычной модели (т.е. не квантованной!), чтобы оценить, дает ли она правильный результат, а затем перейти к квантованной версии:

class PReLU_2(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight

    def forward(self, inputs):
        x = self.weight
        tmin, _ = torch.min(inputs,dim=0)
        tmax, _ = torch.max(inputs,dim=0)
        weight_min_res = torch.mul(x, tmin)
        inputs = torch.add(tmax, weight_min_res)
        inputs = inputs.unsqueeze(0)
        return inputs

Что мне здесь не хватает?

1 Ответ

2 голосов
/ 14 июля 2020

Разобрался! Я совершил огромную ошибку в самом начале. Мне нужно было вычислить

PReLU(x)=max(0,x)+a∗min(0,x)

или enter image description here
and not the actual torch.min! or torch.max! which doesn't make any sense! Here is the final solution for normal models (i.e not quantized)!:

class PReLU_2(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight

    def forward(self, inputs):
        pos = torch.relu(inputs)
        neg = -self.weight * torch.relu(-inputs)
        inputs = pos + neg
        return inputs

и это квантованная версия:

class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = max(0, inputs) + alpha * min(0, inputs) 
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(-self.weight, torch.relu(-inputs))
        inputs = self.quantized_op.add(torch.relu(inputs), weight_min_res)
        inputs = self.dequant(inputs)
        self.weight = self.dequant(self.weight)
        return inputs

Примечание: У меня также была опечатка, когда я вычислял разницу:

    out = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out - out2).mean().item()}')

    out = self.conv2(out)

должно быть

    out1 = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out1 - out2).mean().item()}')
    out = self.conv2(out1)

Обновление:

Если вы столкнулись с проблемами при квантовании, вы Вы можете попробовать эту версию :

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.quantized as nnq
from torch.quantization import fuse_modules


class QPReLU(nn.Module):
    def __init__(self, num_parameters=1, init: float = 0.25):
        super(QPReLU, self).__init__()
        self.num_parameters = num_parameters
        self.weight = nn.Parameter(torch.Tensor(num_parameters).fill_(init))
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.f_mul_neg_one1 = nnq.FloatFunctional()
        self.f_mul_neg_one2 = nnq.FloatFunctional()
        self.f_mul_alpha = nnq.FloatFunctional()
        self.f_add = nnq.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.quant2 = torch.quantization.QuantStub()
        self.quant3 = torch.quantization.QuantStub()
        # self.dequant2 = torch.quantization.QuantStub()
        self.neg_one = torch.Tensor([-1.0])
        
    
    def forward(self, x):
        x = self.quant(x)
        
        # PReLU, with modules only
        x1 = self.relu1(x)
        
        neg_one_q = self.quant2(self.neg_one)
        weight_q = self.quant3(self.weight)
        x2 = self.f_mul_alpha.mul(
            weight_q, self.f_mul_neg_one2.mul(
                self.relu2(
                    self.f_mul_neg_one1.mul(x, neg_one_q),
                ),
            neg_one_q)
        )
        
        x = self.f_add.add(x1, x2)
        x = self.dequant(x)
        return x
    
m1 = nn.PReLU()
m2 = QPReLU()

# check correctness in fp
for i in range(10):
    data = torch.randn(2, 2) * 1000
    assert torch.allclose(m1(data), m2(data))

# toy model
class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.prelu = QPReLU()
        
    def forward(self, x):
        x = self.prelu(x)
        return x
    
# quantize it
m = M()
m.qconfig = torch.quantization.default_qconfig
torch.quantization.prepare(m, inplace=True)
# calibrate
m(torch.randn(4, 4))
# convert
torch.quantization.convert(m, inplace=True)
# run some data through
res = m(torch.randn(4, 4))
print(res)

и обязательно прочтите сводные примечания здесь

...