Проблема в функции backward () в настраиваемой функции потерь (autograd.Function) - PullRequest
0 голосов
/ 20 апреля 2019

Я работаю в двоичной классификации с 1d сверточной сетью. Для исследовательских целей я реализую свою собственную функцию потери, которая похожа на BCELoss. По этой причине я начал пытаться реализовать собственную функцию потери BCE с помощью autograd:

class DiscriminatorLoss(torch.autograd.Function):
          @staticmethod
          def forward(ctx,d_out,labels):
                loss = labels*torch.log(d_out)+(1-labels)*torch.log(1-d_out)
                ctx.d_out,ctx.labels = input,labels
                return loss 

          @staticmethod
          def backward(ctx, grad_output):
                d_out,labels = ctx.d_out,ctx.labels
                grad_input = -labels/d_out + ((1-labels)/(1-d_out))
                return grad_input,None 

Где d_out и labels являются тензорными:

d_out=tensor([[0.5412, 0.5225],     | labels=tensor([[0, 1], 
              [0.5486, 0.5167],     |                [0, 1],
              [0.5391, 0.5061],...])|                [0, 1],...])

Однако это не работает должным образом. Проблема в том, что в середине процесса обучения выходные данные сети (d_out) превращаются в странные значения, такие как:

      tensor([[9.9000e-08, 9.9000e-01],
              [9.9000e-08, 9.9000e-01],
              [9.9000e-08, 9.9000e-01],....])

И он застрянет там до конца тренировки.

Я также реализовал функцию BCELoss из Pytorch nn.BCELoss() (https://pytorch.org/docs/stable/nn.html#loss-functions).. И с этой функцией сеть работает, поэтому я считаю, что проблема в моей функции потерь. Точнее, форвард () работает хорошо, так как возвращает тот же убыток, что и nn.BCELoss. Так что проблема в backward () .

Кто-нибудь может мне помочь? Что я делаю неправильно в функции backward ()?

Спасибо!

PS .: выходы сети обрабатываются не точно 0 или 1, чтобы не генерировать значения NaN и -inf в перекрестной потере энтропии.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...