См. Фрагмент кода:
import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)
Выход tensor([0.])
, но
import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
y = x
else:
y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)
Выход None
.
Я перепутал, почему вывод torch.where
равен tensor([0.])
?
update
import torch
a = torch.tensor([[1,2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)
a[:,0] = b
(a[0, 0] * a[0, 1]).backward()
print(b.grad)
Выход tensor([2., 0.])
. (a[0, 0] * a[0, 1])
никак не связан с b[1]
, но градиент b[1]
равен 0
, а не None
.