Я пытаюсь повторить эксперимент, проведенный Рафаэлем Мюллером и соавт. в « Когда помогает сглаживание меток? » NIPS, 2019
Метод в основном описан в разделе 4 и на рисунке 6 для оценки эффекта сглаживания меток при сетевой дистилляции с задачами классификации: " Мы измеряем взаимную информацию между X и Y, где X - это дискретная переменная, представляющая индекс обучающего примера, а Y - непрерывно, представляющая разницу между двумя логитами (вне классов K). Точная формула, используемая для приближения взаимной информации, записывается на странице 7.
Я пытаюсь повторить этот эксперимент, используя ResNet18 для набора данных Cifar-10.
После каждой эпохи я запускаю этот фрагмент кода для вычисления взаимной информации, как описано в Однако конечное значение вычисленной взаимной информации совершенно далеко от допустимого диапазона от 0 до log (N): размер пакета = 300, N = 600 (количество обучающих примеров, используемых в расчете взаимной инф.), L = 100 (образцы MonteCarlo), а класс__миссия представляет собой 1-й массив, содержащий Индексы двух классов, которые мы используем во время взаимной инф. расчет
with torch.no_grad():
#Mean calculation
for i in range(self.L):
for batch_idx, (inputs, targets) in enumerate(trainloader_sub_transforms[i]):
inputs, targets = inputs.to(self.device), targets.to(self.device)
outputs = (self.net(inputs)).cpu().detach().numpy()
outputs = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
mu_x[batch_idx*300:batch_idx*300 + len(targets)] += outputs
mu_x /= self.L
print('--> Finish Mean Calculation ', mu_x[:10])
#STD Calculation
for batch_idx, (inputs, targets) in enumerate(trainloader_sub):
inputs, targets = inputs.to(self.device), targets.to(self.device)
outputs = (self.net(inputs)).cpu().detach().numpy()
outputs = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
var = np.sum((outputs - mu_x[batch_idx*300:batch_idx*300 + len(targets)]) ** 2)
var /= self.N
print('--> Finish VAR Calculation ', var)
#Mutual Information Calculation
mutual_info_value = np.zeros(self.N)
term2 = 0
for batch_idx, (inputs, targets) in enumerate(trainloader_sub):
print('batch_idx2 = ', batch_idx)
inputs, targets = inputs.to(self.device), targets.to(self.device)
outputs = (self.net(inputs)).cpu().detach().numpy()
outputs = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
main_term = -(outputs - mu_x[batch_idx*300:batch_idx*300+len(targets)])**2 / (2 * var)
term2 += np.sum(np.exp( main_term ))
mutual_info_value[batch_idx*300:batch_idx*300 + len(targets)] = main_term
term2 = np.log(term2)
mutual_info_value -= term2
print(mutual_info_value.shape, mutual_info_value[:10])
mutual_info_value = np.sum(mutual_info_value)
sum += mutual_info_value
print('trial: ', t, ' --> MI = ', mutual_info_value)