I made my own model, AlexNetQIL (AlexNet with QIL layers). "QIL" stands for quantization interval learning.
I trained the model, but the loss value did not decrease at all, and I found that the parameters of my model were not being updated at all because of the QIL layer I added.
I have attached my AlexNetQIL and qil code below; please let me know what the problem is in my code.
AlexNetQIL
import torch
import torch.nn as nn
from qil import *
class AlexNetQIL(nn.Module):

    # def __init__(self, num_classes=1000):  # for ImageNet
    def __init__(self, num_classes=10):  # for CIFAR-10
        super(AlexNetQIL, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.qil2 = Qil()
        self.conv2 = nn.Conv2d(64, 192, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(192)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.qil3 = Qil()
        self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(384)
        self.relu3 = nn.ReLU(inplace=True)

        self.qil4 = Qil()
        self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU(inplace=True)

        self.qil5 = Qil()
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.relu5 = nn.ReLU(inplace=True)
        self.maxpool5 = nn.MaxPool2d(kernel_size=2)

        self.classifier = nn.Sequential(
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
    def forward(self, x, inference=False):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x, self.conv2.weight = self.qil2(x, self.conv2.weight, inference)  # if I remove this line, no problem
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x, self.conv3.weight = self.qil3(x, self.conv3.weight, inference)  # if I remove this line, no problem
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)

        x, self.conv4.weight = self.qil4(x, self.conv4.weight, inference)  # if I remove this line, no problem
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu4(x)

        x, self.conv5.weight = self.qil5(x, self.conv5.weight, inference)  # if I remove this line, no problem
        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu5(x)
        x = self.maxpool5(x)

        x = x.view(x.size(0), 256 * 2 * 2)
        x = self.classifier(x)
        return x
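
My current suspicion is the weight rebinding: forward() wraps both outputs in torch.nn.Parameter and assigns one of them back to self.convN.weight, which creates fresh leaf tensors on every call and cuts the Qil parameters (cw, dw, cx, dx, gamma) out of the autograd graph, so they receive no gradients. A sketch of the alternative I am considering, using the functional conv API so the weight Parameter is never rebound (QuantConv2d is my own name, not an existing API, and it assumes Qil.forward returns plain tensors instead of Parameters):

import torch
import torch.nn as nn
import torch.nn.functional as F

from qil import Qil

class QuantConv2d(nn.Module):
    # Wraps a Conv2d together with a Qil instance; the quantized weight is
    # passed to F.conv2d directly instead of being assigned back to
    # self.conv.weight, so gradients can flow into the Qil parameters.
    def __init__(self, in_ch, out_ch, **conv_kwargs):
        super().__init__()
        self.qil = Qil()
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_kwargs)

    def forward(self, x, inference=False):
        x_q, w_q = self.qil(x, self.conv.weight, inference)  # plain tensors
        return F.conv2d(x_q, w_q, self.conv.bias,
                        stride=self.conv.stride, padding=self.conv.padding)

Each conv2..conv5 block in forward() could then become, e.g., x = self.qconv2(x, inference) followed by the existing bn/relu/pool calls.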
QIL
forward
- quantizes the weights and the input activations in two steps
- transformer (parameters) -> discretizer (parameters)
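
For reference, the weight transformer that transfomer_weights implements below can be written as a piecewise function (this is just the code restated in math, with \alpha_W = 0.5 / d_W and \beta_W = -0.5\, c_W / d_W + 0.5):

T_W(w) =
\begin{cases}
0 & \text{if } |w| < c_W - d_W \\
\operatorname{sign}(w) & \text{if } |w| > c_W + d_W \\
(\alpha_W |w| + \beta_W)^{\gamma} \operatorname{sign}(w) & \text{otherwise}
\end{cases}

transfomer_activation is the same mapping without the gamma exponent and the sign, clipping activations into [0, 1].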
import torch
import torch.nn as nn
# Qil (Quantization Interval Learning)
class Qil(nn.Module):
    discretization_level = 32

    def __init__(self):
        super(Qil, self).__init__()
        self.cw = nn.Parameter(torch.rand(1))  # I have to train this interval parameter
        self.dw = nn.Parameter(torch.rand(1))  # I have to train this interval parameter
        self.cx = nn.Parameter(torch.rand(1))  # I have to train this interval parameter
        self.dx = nn.Parameter(torch.rand(1))  # I have to train this interval parameter
        self.gamma = nn.Parameter(torch.tensor(1.0))  # I have to train this transformer parameter
        self.a = Qil.discretization_level

    def forward(self, x, weights, inference=False):
        if not inference:
            weights = self.transfomer_weights(weights)
            weights = self.discretizer(weights)
        x = self.transfomer_activation(x)
        x = self.discretizer(x)
        return torch.nn.Parameter(x), torch.nn.Parameter(weights)

    def transfomer_weights(self, weights):
        device = weights.device
        aw, bw = (0.5 / self.dw), (-0.5 * self.cw / self.dw + 0.5)
        weights = torch.where(abs(weights) < self.cw - self.dw,
                              torch.tensor(0.).to(device), weights)
        weights = torch.where(abs(weights) > self.cw + self.dw,
                              weights.sign(), weights)
        weights = torch.where((abs(weights) >= self.cw - self.dw) & (abs(weights) <= self.cw + self.dw),
                              (aw * abs(weights) + bw) ** self.gamma * weights.sign(), weights)
        return weights

    def transfomer_activation(self, x):
        device = x.device
        ax, bx = (0.5 / self.dx), (-0.5 * self.cx / self.dx + 0.5)
        x = torch.where(x < self.cx - self.dx,
                        torch.tensor(0.).to(device), x)
        x = torch.where(x > self.cx + self.dx,
                        torch.tensor(1.0).to(device), x)
        x = torch.where((abs(x) >= self.cx - self.dx) & (abs(x) <= self.cx + self.dx),
                        ax * abs(x) + bx, x)
        return x

    def discretizer(self, tensor):
        q_D = pow(2, Qil.discretization_level)
        tensor = torch.round(tensor * q_D) / q_D
        return tensor