Взаимная информация между входным тренировочным изображением и разницей между двумя логитами - PullRequest
0 голосов
/ 21 апреля 2020

Я пытаюсь повторить эксперимент, проведенный Рафаэлем Мюллером и соавт. в « Когда помогает сглаживание меток? » NIPS, 2019

Метод в основном описан в разделе 4 и на рисунке 6 для оценки эффекта сглаживания меток при сетевой дистилляции с задачами классификации: " Мы измеряем взаимную информацию между X и Y, где X - это дискретная переменная, представляющая индекс обучающего примера, а Y - непрерывно, представляющая разницу между двумя логитами (вне классов K). Точная формула, используемая для приближения взаимной информации, записывается на странице 7.

Я пытаюсь повторить этот эксперимент, используя ResNet18 для набора данных Cifar-10.

После каждой эпохи я запускаю этот фрагмент кода для вычисления взаимной информации, как описано в Однако конечное значение вычисленной взаимной информации совершенно далеко от допустимого диапазона от 0 до log (N): размер пакета = 300, N = 600 (количество обучающих примеров, используемых в расчете взаимной инф.), L = 100 (образцы MonteCarlo), а класс__миссия представляет собой 1-й массив, содержащий Индексы двух классов, которые мы используем во время взаимной инф. расчет

with torch.no_grad():
    #Mean calculation
    for i in range(self.L):
        for batch_idx, (inputs, targets) in enumerate(trainloader_sub_transforms[i]):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            outputs = (self.net(inputs)).cpu().detach().numpy()
            outputs = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
            mu_x[batch_idx*300:batch_idx*300 + len(targets)] += outputs
    mu_x /= self.L
    print('--> Finish Mean Calculation ', mu_x[:10])
    #STD Calculation
    for batch_idx, (inputs, targets) in enumerate(trainloader_sub):
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        outputs = (self.net(inputs)).cpu().detach().numpy()
        outputs = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
        var = np.sum((outputs - mu_x[batch_idx*300:batch_idx*300 + len(targets)]) ** 2)
    var /= self.N
    print('--> Finish VAR Calculation ', var)
    #Mutual Information Calculation
    mutual_info_value = np.zeros(self.N)
    term2 = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader_sub):
        print('batch_idx2 = ', batch_idx)
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        outputs = (self.net(inputs)).cpu().detach().numpy()
        outputs = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
        main_term = -(outputs - mu_x[batch_idx*300:batch_idx*300+len(targets)])**2 / (2 * var)
        term2  +=  np.sum(np.exp( main_term ))
        mutual_info_value[batch_idx*300:batch_idx*300 + len(targets)] = main_term
    term2 = np.log(term2)
    mutual_info_value -= term2
    print(mutual_info_value.shape, mutual_info_value[:10])
    mutual_info_value = np.sum(mutual_info_value)
sum += mutual_info_value
print('trial: ', t, ' --> MI = ', mutual_info_value)
...