Объяснение следующих результатов Pytorch - PullRequest
1 голос
/ 21 июня 2020

Я пытаюсь получить более глубокое представление о том, как работает автоград Pytorch. Я не могу объяснить следующие результаты:

import torch
def fn(a):
 b = torch.tensor(5,dtype=torch.float32,requires_grad=True)
 return a*b 

a  = torch.tensor(10,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)

Результат - тензор (5.). Но мой вопрос в том, что переменная b создается внутри функции и поэтому должна быть удалена из памяти после того, как функция вернет a * b, верно? Итак, когда я звоню в обратном порядке, как значение b все еще присутствует для разрешения этого вычисления? Насколько я понимаю, каждая операция в Pytorch имеет контекстную переменную, которая отслеживает «какой» тензор использовать для обратных вычислений, и также есть версии, присутствующие в каждом тензоре, и если версия изменяется, то обратное должно вызывать ошибку, верно?

Теперь, когда я пытаюсь запустить следующий код,

import torch
def fn(a):
 b = a**2
 for i in range(5):
   b *= b
 return b 

a  = torch.tensor(10,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)

, я получаю следующую ошибку: одна из переменных, необходимых для вычисления градиента, была изменена операцией на месте: [torch.FloatTensor [] ], который является выходом 0 MulBackward0, находится в версии 5; вместо этого ожидается версия 4. Подсказка: включите обнаружение аномалий, чтобы найти операцию, при которой не удалось вычислить градиент, с помощью torch.autograd.set_detect_anomaly (True).

Но если я запустил следующий код, ошибки не будет:

import torch
def fn(a):
  b = a**2
  for i in range(2):
    b = b*b
  return b

def fn2(a):
  b = a**2
  c = a**2
  for i in range(2):
    c *= b
  return c

a  = torch.tensor(5,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)
output2 = fn2(a)
output2.backward()
print(a.grad)

Результатом этого является:

тензор (625000.)

тензор (643750.)

Итак, для стандартных графов вычислений с довольно большим количеством переменных, в той же функции я могу понять, как работает граф вычислений. Но когда перед вызовом обратной функции изменяется переменная, у меня возникают большие проблемы с пониманием результатов. Может кто-нибудь объяснить?

1 Ответ

3 голосов
/ 21 июня 2020

Обратите внимание, что b *=b не то же самое, что b = b*b.

Возможно, это сбивает с толку, но основные операции различаются.

В случае b *=b in- выполняется операция place, которая приводит к путанице с градиентами и, следовательно, с RuntimeError.

В случае b = b*b, два тензорных объекта умножаются, и результирующему объекту присваивается имя b. Таким образом, нет RuntimeError, когда вы запускаете таким образом.

Вот вопрос SO по базовой операции python: Разница между x + = y и x = x + y

В чем разница между fn в первом случае и fn2 во втором? Операция c*=b не уничтожает связи графика с b из c. Операция c*=c сделала бы невозможным создание графа, соединяющего два тензора с помощью операции.

Ну, я не могу работать с тензорами, чтобы продемонстрировать это, потому что они вызывают RuntimeError. Поэтому я попробую со списком python.

>>> x = [1,2]
>>> y = [3]
>>> id(x), id(y)
(140192646516680, 140192646927112)
>>>
>>> x += y
>>> x, y
([1, 2, 3], [3])
>>> id(x), id(y)
(140192646516680, 140192646927112)

Обратите внимание, что новый объект не создан. Таким образом, невозможно отследить от output до исходных переменных. Мы не можем различать guish и object_140192646516680 как выход или вход. Итак, как создать график с этим ..

Рассмотрим следующий альтернативный случай:

>>> a = [1,2]
>>> b = [3]
>>>
>>> id(a), id(b)
(140192666168008, 140192666168264)
>>>
>>> a = a + b
>>> a, b
([1, 2, 3], [3])
>>> id(a), id(b)
(140192666168328, 140192666168264)
>>>

Обратите внимание, что новый список a на самом деле является новым объектом с id 140192666168328. Здесь мы можем проследить, что object_140192666168328 пришел из addition operation между двумя другими объектами object_140192666168008 и object_140192666168264. Таким образом, можно динамически создавать график и распространять градиенты с output на предыдущие слои.

...