Это очень простой пример:
import torch
x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)
s = torch.sum(x * y * z)
s.backward()
print(x.grad)
Будет напечатано,
tensor([2., 2., 0., 0., 0.]),
, поскольку, конечно, ds / dx равно нулю для записей, где z равно нулю.
Мой вопрос: является ли Pytorch умным и останавливает вычисления, когда он достигает нуля?Или на самом деле выполняет вычисление "2*5
", только чтобы потом сделать "10 * 0 = 0
"?
В этом простом примере это не имеет большого значения, но в (большей) проблеме ясмотрю, это будет иметь значение.
Спасибо за любой вклад.