Pytorch: «Вес модели не меняется» - PullRequest
0 голосов
/ 24 ноября 2018

Может кто-нибудь помочь мне понять, почему веса не обновляются?

    unet = Unet()
    optimizer = torch.optim.Adam(unet.parameters(), lr=0.001)
    loss_fn = torch.nn.MSELoss()
    input =  Variable(torch.randn(32, 1, 64, 64, 64 ), requires_grad=True)
    target = Variable(torch.randn(32, 1, 64, 64, 64), requires_grad=False)

    optimizer.zero_grad()
    y_pred = unet(input)
    y = target[: , : , 20:44, 20:44, 20:44]

    loss = loss_fn(y_pred, y)
    print(unet.conv1.weight.data[0][0]) # weights of the first layer in the unet
    loss.backward()
    optimizer.step()
    print(unet.conv1.weight.data[0][0]) # weights havent changed

Модель определяется как:

class Unet(nn.Module):

def __init__(self):
  super(Unet, self).__init__()

  # Down hill1
  self.conv1 = nn.Conv3d(1, 2, kernel_size=3,  stride=1)
  self.conv2 = nn.Conv3d(2, 2, kernel_size=3,  stride=1)

  # Down hill2
  self.conv3 = nn.Conv3d(2, 4, kernel_size=3,  stride=1)
  self.conv4 = nn.Conv3d(4, 4, kernel_size=3,  stride=1)

  #bottom
  self.convbottom1 = nn.Conv3d(4, 8, kernel_size=3,  stride=1)
  self.convbottom2 = nn.Conv3d(8, 8, kernel_size=3,  stride=1)

  #up hill1
  self.upConv0 = nn.Conv3d(8, 4, kernel_size=3,  stride=1)
  self.upConv1 = nn.Conv3d(4, 4, kernel_size=3,  stride=1)
  self.upConv2 = nn.Conv3d(4, 2, kernel_size=3,  stride=1)

  #up hill2
  self.upConv3 = nn.Conv3d(2, 2, kernel_size=3, stride=1)
  self.upConv4 = nn.Conv3d(2, 1, kernel_size=1, stride=1)

  self.mp = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
  # some more irrelevant properties...

Функция пересылки выглядит следующим образом:

def forward(self, input):
    # Use U-net Theory to Update the filters.
    # Example Approach...
    input = F.relu(self.conv1(input))
    input = F.relu(self.conv2(input))

    input = self.mp(input)

    input = F.relu(self.conv3(input))
    input = F.relu(self.conv4(input))

    input = self.mp(input)

    input = F.relu(self.convbottom1(input))
    input = F.relu(self.convbottom2(input))

    input = F.interpolate(input, scale_factor=2, mode='trilinear')

    input = F.relu(self.upConv0(input))
    input = F.relu(self.upConv1(input))

    input = F.interpolate(input, scale_factor=2, mode='trilinear')


    input = F.relu(self.upConv2(input))
    input = F.relu(self.upConv3(input))

    input = F.relu(self.upConv4(input))

    return input

Я следовал подходу любого примера и документации, которую смог найти, и мне кажется, почему это не работает?

Я могу понять, что y_pred.grad после обратного вызова - это не то, что не должно быть.Если у нас нет градиента, то, конечно, оптимизатор не может изменять веса в любом направлении, но почему нет градиента?

Ответы [ 2 ]

0 голосов
/ 06 марта 2019

Я определил, что эта проблема относится к «проблеме умирающего ReLu». Поскольку данные были единицами Хаунсфилда, и равномерное распределение начальных весов в Pytorch означало, что многие нейроны начинали в нулевой области ReLu, оставляя их парализованными и зависимыми от других нейронов.создать градиент, который может вытащить их из нулевой области.Это вряд ли произойдет, поскольку обучение прогрессирует, все нейроны выталкиваются в нулевую область ReLu.

Есть несколько решений этой проблемы.Вы можете использовать Leaky_relu или другие функции активации, которые не имеют нулевой регион.

Вы также можете нормализовать входные данные с помощью нормализации партии и инициализировать веса, чтобы они имели только положительный вид.

Решение номер два, вероятно, является наиболее оптимальным решением, поскольку оба решат проблему, но leaky_relu продлит обучение, а нормализация партии сделает обратное и повысит точность.С другой стороны, Leaky_relu легко исправить, тогда как другое решение требует немного дополнительной работы.

Для данных Хаунсфилда можно также добавить константу 1000 к входу, исключив отрицательные единицы из данных.Это все еще требует иной инициализации веса, чем стандартная инициализация Pytorch.

0 голосов
/ 25 ноября 2018

Я не думаю, что веса должны быть напечатаны командой, которую вы используете.Попробуйте print(unet.conv1.state_dict()["weight"]) вместо print(unet.conv1.weight.data[0][0]).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...