Градиентное поведение в pytorch с многослойной потерей - PullRequest
0 голосов
/ 10 мая 2019

У меня есть потеря, когда каждый слой играет на потерю.Какой правильный подход с точки зрения обеспечения правильного обновления весов?

# option 1
x2 = self.layer1(x1)
x3 = self.layer2(x2)
x4 = self.layer3(x3)

В этом варианте я отключаюсь при подаче в каждый последующий блок

    # option 2
    # x2 = self.layer1(x1.detach())
    # x3 = self.layer2(x2.detach())
    # x4 = self.layer3(x3.detach())

общих операций, которые вычисляют4 потери и сложите их.

    x4 = F.relu(self.bn1(x4))
    loss = some_loss([x1, x2, x3, x4])

1 Ответ

0 голосов
/ 12 мая 2019

Вариант 1 правильный. При отсоединении тензора история / график вычислений теряется, и градиенты не будут распространяться на входы / для вычислений, выполняемых перед отсоединением.

Это также можно увидеть в этом игрушечном эксперименте.

In [14]: import torch                                                                                                                                                                                 

In [15]: x = torch.rand(10,10).requires_grad_()                                                                                                                                                       

In [16]: y = x**2                                                                                                                                                                                     

In [19]: z = torch.sum(y)                                                                                                                                                                             

In [20]: z.backward()                                                                                                                                                                                 

In [23]: x.grad is not None                                                                                                                                                                           
Out[23]: True

Использование detach

In [26]: x = torch.rand(10,10).requires_grad_()                                                                                                                                                       

In [27]: y = x**2                                                                                                                                                                                     

In [28]: z = torch.sum(y)                                                                                                                                                                             

In [29]: z_ = z.detach()                                                                                                                                                                              

In [30]: z_.backward()  
# this gives error

Это потому, что когда вы вызываете detach, он возвращает новый тензор со скопированными значениями и информация о предыдущих вычислениях теряется.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...