Функция PyTorch optimizer.step () не обновляет вес - PullRequest
0 голосов
/ 05 декабря 2018

Код можно увидеть ниже.
Проблема в том, что часть optimizer.step() не работает.Я печатаю model.parameters() до и после тренировки, и веса не меняются.

Я пытаюсь создать персептрон, который может решить проблему AND.Мне удалось это сделать с моей собственной крошечной библиотекой, где я реализовал персептрон с двумя функциями predict() и train().

Просто чтобы прояснить, я только начал изучать глубокое обучение с использованием PyTorch, так что это, вероятно, очень новая проблема.Я пытался найти решение, но без удачи.Я также сравнил свой код с другими работающими кодами, но я не знаю, что я делаю неправильно.

import torch
from torch import nn, optim
from random import randint

class NeuralNet(nn.Module):
  def __init__(self):
    super(NeuralNet, self).__init__()
    self.layer1 = nn.Linear(2, 1)

  def forward(self, input):
    out = input
    out = self.layer1(out)
    out = torch.sign(out)
    out = torch.clamp(out, 0, 1) # 0=false, 1=true
    return out

data = torch.Tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
target = torch.Tensor([0, 0, 0, 1])
model = NeuralNet()
epochs = 1000
lr = 0.01

print(list(model.parameters()))
print() # Print parameters before training
loss_func = nn.L1Loss()
optimizer = optim.Rprop(model.parameters(), lr)
for epoch in range(epochs + 1):
  optimizer.zero_grad()
  rand_int = randint(0, len(data) - 1)
  x = data[rand_int]
  y = target[rand_int]

  pred = model(x)
  loss = loss_func(pred, y)

  loss.backward()
  optimizer.step()

# Print parameters again
# But they haven't changed
print(list(model.parameters()))

1 Ответ

0 голосов
/ 05 декабря 2018

Добро пожаловать в stackoverflow!

Проблема в том, что вы пытаетесь выполнить обратное распространение через недифференцируемую функцию.Не дифференцируемый означает, что никакие градиенты не могут течь обратно через них, подразумевая, что все обучаемые веса, примененные перед ними, не будут обновлены вашим оптимизатором.Такие функции легко обнаружить;это дискретные, острые операции, которые напоминают операторы 'if'.В вашем случае это функция sign().

К сожалению, PyTorch не делает никаких действий в этом отношении и не укажет вам на проблему.Чтобы облегчить проблему, вы могли бы преобразовать диапазон выходного сигнала в [-1,1] и применить нелинейность Tanh() вместо операторов sign() и clamp().

...