Я использую кросс-энтропию && потери дистилляции для модели непрерывного обучения (также называемого инкрементным обучением). Но потери при дистилляции не сходятся, а потери при перекрестной энтропии сходятся.
Вот мой код потери при дистилляции и обучающей части.
Код потери при дистилляции:
def loss_fn_distillation(outputs, soft_labels, temperature, current_step, total_step, total_label):
current_label = (total_label / total_step) * (current_step + 1)
previous_label = (total_label / total_step) * current_step
soft_labels = V(soft_labels.data, requires_grad=False).cuda()
soft_labels = torch.softmax(soft_labels / temperature, dim=1)
outputs = F.log_softmax(outputs[:,:-int(current_label-previous_label)]/temperature, dim = 1)
distill_loss = torch.sum(outputs * soft_labels, dim=1, keepdim=False)
distill_loss = -torch.mean(distill_loss, dim=0, keepdim=False)
return V(distill_loss, requires_grad=True).cuda()
Код для обучающей части:
outputs = net(inputs)
ce_loss = criterion(outputs, targets)
if(i>0) :
soft_label = previous_net(inputs)
distill_loss = loss_fn_distillation(outputs=outputs, soft_labels=soft_label, temperature=2,
current_step=i, total_step=step, total_label=number_label)
print(ce_loss, distill_loss)
loss = distill_loss + ce_loss
else :
loss = ce_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
Результат для перекрестной потери энтропии && потери дистилляции за эпоху:
Буду признателен за любые отзывы , Спасибо.