В чем разница между оператором if-else и torch.where в pytorch? - PullRequest
3 голосов
/ 13 апреля 2020

См. Фрагмент кода:

import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)

Выход tensor([0.]), но

import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
    y = x
else:
    y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)

Выход None.

Я перепутал, почему вывод torch.where равен tensor([0.])?

update

import torch
a = torch.tensor([[1,2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)
a[:,0] = b

(a[0, 0] * a[0, 1]).backward()
print(b.grad)

Выход tensor([2., 0.]). (a[0, 0] * a[0, 1]) никак не связан с b[1], но градиент b[1] равен 0, а не None.

1 Ответ

2 голосов
/ 13 апреля 2020

AD на основе отслеживания, как и pytorch, работает по tracking . Вы не можете отслеживать вещи, которые не являются вызовами функций, перехваченными библиотекой. Используя оператор if, подобный этому, нет никакой связи между x и y, тогда как с where, x и y связаны в дереве выражений.

Теперь, для разностей:

  • В первом фрагменте 0 является правильной производной функции x ↦ x > 0 ? x : 2 в точке -1 (поскольку отрицательная сторона постоянна).
  • Во втором фрагменте, как я уже сказал, x никак не связан с y (в ветви else). Следовательно, производная от y с учетом x не определена, что представляется как None.

(Вы можете делать такие вещи даже в Python, но для этого требуется более сложная технология как преобразование источника. Я не думаю, что это возможно с pytorch.)

...