PyTorch optimizer does not update parameters
0 votes
/ 08 March 2020
  1. I made my own model, AlexNetQIL (AlexNet with a QIL layer). "QIL" stands for Quantization Interval Learning.

  2. I trained my model, but the loss value did not decrease at all, and I found that the parameters of my model were not updated at all because of the QIL layer I added (see the sketch after this list for how this can be checked).

  3. I have attached my AlexNetQIL and qil code below; please let me know what is wrong with them.
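
A minimal sketch of the kind of check item 2 refers to (names like model, optimizer, criterion, images and labels are placeholders, not from my actual script):

import torch

# Snapshot every parameter before one optimizer step, then compare afterwards
# (model, optimizer, criterion, images, labels are assumed to exist already).
before = {name: p.detach().clone() for name, p in model.named_parameters()}

optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()

for name, p in model.named_parameters():
    changed = not torch.equal(before[name], p.detach())
    print(name, 'updated' if changed else 'NOT updated')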

AlexNetQIL

import torch
import torch.nn as nn
from qil import *

class AlexNetQIL(nn.Module):

    #def __init__(self, num_classes=1000): for imagenet
    def __init__(self, num_classes=10): # for cifar-10
        super(AlexNetQIL, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.qil2 = Qil()
        self.conv2 = nn.Conv2d(64, 192, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(192)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.qil3 = Qil()
        self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(384)
        self.relu3 = nn.ReLU(inplace=True)

        self.qil4 = Qil()
        self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU(inplace=True)

        self.qil5 = Qil()
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.relu5 = nn.ReLU(inplace=True)
        self.maxpool5 = nn.MaxPool2d(kernel_size=2)

        self.classifier = nn.Sequential(
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
    def forward(self, x, inference=False):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x,self.conv2.weight = self.qil2(x,self.conv2.weight,inference ) # if I remove this line, No problem 
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x,self.conv3.weight = self.qil3(x,self.conv3.weight,inference ) # if I remove this line, No problem 
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)

        x,self.conv4.weight = self.qil4(x,self.conv4.weight,inference ) # if I remove this line, No problem 
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu4(x)

        x,self.conv5.weight = self.qil5(x,self.conv5.weight,inference ) # if I remove this line, No problem 
        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu5(x)
        x = self.maxpool5(x)
        x = x.view(x.size(0),256 * 2 * 2)
        x = self.classifier(x)
        return x
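
I have not included my training loop above; a minimal loop for this model would look roughly like the following (a sketch only: the CIFAR-10-sized 32x32 inputs, CrossEntropyLoss, SGD and the dummy batch are assumptions standing in for my real setup and DataLoader):

import torch
import torch.nn as nn
import torch.optim as optim

model = AlexNetQIL(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Dummy CIFAR-10-sized batch instead of a real DataLoader.
images = torch.randn(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))

for step in range(10):
    optimizer.zero_grad()
    outputs = model(images)            # QIL layers quantize weights/activations on the way through
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(step, loss.item())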

QIL

forward

  • quantize the weights and the input activations in 2 steps
  • transformer (parameters) -> discretizer (parameters), spelled out just below
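
Concretely, the weight transformer in the code below implements this piecewise function (my reading of the code; c_w and d_w are the learnable interval center and width, gamma is the learnable exponent; the activation transformer is analogous with c_x, d_x, maps into [0, 1] and has no gamma):

T_w(w) =
\begin{cases}
0 & \text{if } |w| < c_w - d_w \\
\operatorname{sign}(w) & \text{if } |w| > c_w + d_w \\
\left(a_w |w| + b_w\right)^{\gamma}\,\operatorname{sign}(w) & \text{otherwise}
\end{cases}
\qquad
a_w = \frac{0.5}{d_w}, \quad b_w = -\frac{0.5\,c_w}{d_w} + 0.5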
import torch
import torch.nn as nn
import numpy as np
import copy

#Qil (Quantize intervals learning)
class Qil(nn.Module):

    discretization_level = 32
    def __init__(self):
        super(Qil,self).__init__()
        self.cw = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.dw = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.cx = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.dx = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.gamma = nn.Parameter(torch.tensor(1.0))  # I have to train this transformer parameter

        self.a = Qil.discretization_level
    def forward(self, x, weights, Inference=False):
        if not Inference:
            weights = self.transformer_weights(weights)
            weights = self.discretizer(weights)
        x = self.transformer_activation(x)
        x = self.discretizer(x)
        return torch.nn.Parameter(x), torch.nn.Parameter(weights)

    def transformer_weights(self, weights):
        device = weights.device
        aw,bw = (0.5 / self.dw) , (-0.5*self.cw / self.dw + 0.5)

        weights = torch.where( abs(weights) < self.cw - self.dw,
                                torch.tensor(0.).to(device),weights)
        weights = torch.where( abs(weights) > self.cw + self.dw,
                                weights.sign(), weights)
        weights = torch.where( (abs(weights) >= self.cw - self.dw) & (abs(weights) <= self.cw + self.dw),
                                (aw*abs(weights) + bw)**self.gamma * weights.sign() , weights)
        return weights

    def transformer_activation(self, x):
        device = x.device
        ax,bx = (0.5 / self.dx) , (-0.5*self.cx / self.dx + 0.5)

        x = torch.where(x < self.cx - self.dx,
                        torch.tensor(0.).to(device),x)
        x = torch.where(x > self.cx + self.dx,
                        torch.tensor(1.0).to(device),x)
        x = torch.where( (abs(x) >= self.cx - self.dx) & (abs(x) <= self.cx + self.dx),
                            ax*abs(x) + bx, x)
        return x

    def discretizer(self,tensor):
        q_D = pow(2, Qil.discretization_level)
        tensor = torch.round(tensor * q_D) / q_D
        return tensor
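
For completeness, a standalone sketch of how the layer is called from the model's forward pass (hypothetical shapes, only to show the calling convention and to inspect what comes back):

import torch
import torch.nn as nn
from qil import Qil

conv = nn.Conv2d(64, 192, kernel_size=3, padding=1)
qil = Qil()

x = torch.rand(1, 64, 16, 16)
x_q, w_q = qil(x, conv.weight, False)  # same call pattern as in AlexNetQIL.forward
print(type(x_q), type(w_q))            # both returned values come back wrapped in nn.Parameter
print(x_q.grad_fn, w_q.grad_fn)        # inspect whether they are still attached to the autograd graph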



...