Делает ли pytorch готовое сокращение своего вычислительного графа? - PullRequest
0 голосов
/ 20 февраля 2019

Это очень простой пример:

import torch

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)

s = torch.sum(x * y * z)
s.backward()

print(x.grad)

Будет напечатано,

tensor([2., 2., 0., 0., 0.]),

, поскольку, конечно, ds / dx равно нулю для записей, где z равно нулю.

Мой вопрос: является ли Pytorch умным и останавливает вычисления, когда он достигает нуля?Или на самом деле выполняет вычисление "2*5", только чтобы потом сделать "10 * 0 = 0"?

В этом простом примере это не имеет большого значения, но в (большей) проблеме ясмотрю, это будет иметь значение.

Спасибо за любой вклад.

1 Ответ

0 голосов
/ 20 февраля 2019

Нет, pytorch не делает такой вещи, чтобы обрезать любые последующие вычисления при достижении нуля.Хуже того, из-за того, как работает арифметика с плавающей запятой, все последующее умножение на ноль займет примерно столько же времени, сколько и любое обычное умножение.

Хотя в некоторых случаях есть способы обойти это, например, если вы хотите использовать замаскированныйПотеря вы можете просто установить маскированные выходы равными нулю или отделить их от градиентов.

Этот пример проясняет разницу:

def time_backward(do_detach):
    x = torch.tensor(torch.rand(100000000), requires_grad=True)
    y = torch.tensor(torch.rand(100000000), requires_grad=True)
    s2 = torch.sum(x * y)
    s1 = torch.sum(x * y)
    if do_detach:
        s2 = s2.detach()
    s = s1 + 0 * s2
    t = time.time()
    s.backward()
    print(time.time() - t)

time_backward(do_detach= False)
time_backward(do_detach= True)

выходы:

0.502875089645
0.198422908783
...