Почему потери дистилляции не сходятся (уменьшаются) по сравнению с потерей перекрестной энтропии? - PullRequest
1 голос
/ 27 февраля 2020

Я использую кросс-энтропию && потери дистилляции для модели непрерывного обучения (также называемого инкрементным обучением). Но потери при дистилляции не сходятся, а потери при перекрестной энтропии сходятся.

Вот мой код потери при дистилляции и обучающей части.

Код потери при дистилляции:

def loss_fn_distillation(outputs, soft_labels, temperature, current_step, total_step, total_label):
    current_label = (total_label / total_step) * (current_step + 1)
    previous_label = (total_label / total_step) * current_step

    soft_labels = V(soft_labels.data, requires_grad=False).cuda()
    soft_labels = torch.softmax(soft_labels / temperature, dim=1)

     outputs = F.log_softmax(outputs[:,:-int(current_label-previous_label)]/temperature, dim = 1)

     distill_loss = torch.sum(outputs * soft_labels, dim=1, keepdim=False)
     distill_loss = -torch.mean(distill_loss, dim=0, keepdim=False)


      return V(distill_loss, requires_grad=True).cuda()

Код для обучающей части:

    outputs = net(inputs)
    ce_loss = criterion(outputs, targets)

    if(i>0) :

        soft_label = previous_net(inputs)

        distill_loss = loss_fn_distillation(outputs=outputs, soft_labels=soft_label, temperature=2,
                                            current_step=i, total_step=step, total_label=number_label)
        print(ce_loss, distill_loss)
        loss = distill_loss + ce_loss


    else :
        loss = ce_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Результат для перекрестной потери энтропии && потери дистилляции за эпоху:

enter image description here

Буду признателен за любые отзывы , Спасибо.

...