Почему PyTorch требует сохранения графика? - PullRequest
0 голосов
/ 28 сентября 2018

Я тренирую свою модель следующим образом:

for i in range(5):
  optimizer.zero_grad()
  y = next_input()
  loss = model(y)
  loss.backward()
  optimizer.step()

и получаю эту ошибку

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Почему мне требуется сохранить график?Он может просто пересчитать производные, если они освобождены.Чтобы доказать это, рассмотрим код:

for i in range(5):
  optimizer.zero_grad()
  model.zero_grad() # drop derivatives
  y = next_input()
  loss = model(y)
  loss.backward(retain_graph=True)
  optimizer.step()

В этом случае производные от предыдущей итерации также обнуляются, но Torch это не волнует, поскольку установлен флаг retain_graph=True.

Прав ли я, что model.zero_grad() отменяет эффект retain_graph=True (т.е. удаляет оставшиеся производные)?

Ответы [ 2 ]

0 голосов
/ 02 октября 2018

Поскольку рассматриваемые градиенты являются градиентами модели, правильный код должен быть model.zero_grad().Я не уверен, что optimizer.zero_grad() сработает, потому что я никогда не пробовал.Ваш первый пример будет:

for i in range(5):
  model.zero_grad()  # instead of optimizer.zero_grad()
  x, y = next_input_output_pair()  # We get both input and expected output
  loss = mean_squared_error(model(x), y)  # the loss is calculated
  loss.backward()  # backward calculation
  optimizer.step()
0 голосов
/ 28 сентября 2018

Вам нужно обнулять градиенты после каждого прохода, чтобы дать факелу команду выбросить ранее накопленные значения, в противном случае, каждый раз, когда вы вызываете loss.backward(), он будет проходить в обратном направлении через все вычисления.

Так что код долженbe

for i in range(5):
  y = next_input()
  loss = model(y)
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()

Поскольку вы не обнуляете градиенты, pytorch пытается выполнить обратное распространение в ходе предыдущих вычислений, поэтому выдает ошибку о сохранении графа.Если вы сохраняете график, по сути вы не отбрасываете накопленные градиенты предыдущих шагов.

Это обсуждение на форуме pytorch может быть полезно для вас.Он подчеркивает то же, что я упоминал выше

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...