Градиент PyTorch отличается от градиента, рассчитанного вручную - PullRequest
0 голосов
/ 13 ноября 2018

Я пытаюсь вычислить градиент 1 / x без использования автограда Pytorch. Я использую формулу grad (1 / x, x) = -1 / x ** 2. Когда я сравниваю свой результат с этой формулой с градиентом, заданным автоградом Pytorch, они отличаются.

Вот мой код:

a = torch.tensor(np.random.randn(), dtype=dtype, requires_grad=True)
loss = 1/a
loss.backward()
print(a.grad - (-1/(a**2)))

Вывод:

tensor(5.9605e-08, grad_fn=<ThAddBackward>)

Может кто-нибудь объяснить мне, в чем проблема?

1 Ответ

0 голосов
/ 13 ноября 2018

Так что я предполагаю, что вы ожидаете ноль в качестве результата. Когда вы смотрите поближе, вы видите, что это довольно близко. При делении чисел в двоичной системе (компьютере) вы часто получаете ошибки округления.

Давайте рассмотрим ваш пример с дополнительным print-Statement добавлено:

a = torch.tensor(np.random.randn(), requires_grad=True)
loss = 1/a
loss.backward()
print(a.grad, (-1/(a**2)))
print(a.grad - (-1/(a**2)))

Из-за случайного ввода вывод, конечно, тоже случайный (поэтому вы не получите эти самые цифры, просто повторите и у вас будут похожие примеры) , иногда вы также получите ноль в качестве результата, здесь обратите внимание на случай:

tensor(-0.9074) tensor(-0.9074, grad_fn=<MulBackward>)
tensor(5.9605e-08, grad_fn=<ThSubBackward>)

Вы видите, хотя оба отображаются как одно и то же число, но они отличаются в одном из последних десятичных знаков. Вот почему вы получаете эту очень маленькую разницу при вычитании обоих.

Эта проблема, как общая проблема компьютеров, в некоторых дробях просто большое или бесконечное количество десятичных знаков, а в вашей памяти - нет. Так что они отрезаны в какой-то момент.

Итак, то, что вы испытываете здесь, на самом деле является недостатком точности. И точность зависит от используемого вами числового типа данных (т. Е. torch.float32 или torch.float64).

Вы также можете посмотреть здесь больше информации:
https://en.wikipedia.org/wiki/Double-precision_floating-point_format


Но это не относится к PyTorch или около того, вот пример Python:

print(29/100*100)

Результаты:

28.999999999999996

Edit:

Как указал @HOANG GIANG, изменение уравнения на - (1 / a) * (1 / a) работает хорошо, и результат равен нулю. Вероятно, это так, потому что вычисление, выполненное для вычисления градиента, очень похоже (или одинаково) на - (1 / a) * (1 / a) в этом случае. Поэтому он имеет те же ошибки округления, поэтому разница равна нулю.

Итак, вот еще один более подходящий пример, чем приведенный выше. Даже если - (1 / x) * (1 / x) математически эквивалентно -1 / x ^ 2 , при вычислении на компьютере это не всегда одинаково, в зависимости от на значение x :

import numpy as np
print('e1 == e2','x value', '\t'*2, 'round-off error', sep='\t')
print('='*70)
for i in range(10):
    x = np.random.randn()
    e1 = -(1/x)*(1/x)
    e2 = (-1/(x**2))
    print(e1 == e2, x, e1-e2, sep='\t\t')

Выход:

e1 == e2    x value                 round-off error
======================================================================
True        0.2934154339948173      0.0
True        -1.2881863891014191     0.0
True        1.0463038021843876      0.0
True        -0.3388766143622498     0.0
True        -0.6915415747192347     0.0
False       1.3299049850551317      1.1102230246251565e-16
True        -1.2392046539563553     0.0
False       -0.42534236747121645    8.881784197001252e-16
True        1.407198823994324       0.0
False       -0.21798652132356966    3.552713678800501e-15


Даже при том, что ошибка округления кажется немного меньше (я пробовал разные случайные значения, и редко более двух из десяти имели ошибку округления ) , но все же уже есть небольшие различия при расчете 1 / x :

import numpy as np
print('e1 == e2','x value', '\t'*2, 'round-off error', sep='\t')
print('='*70)
for i in range(10):
    x = np.random.randn()
    # calculate 1/x
    result = 1/x
    # apply inverse function
    reconstructed_x = 1/result
    # mathematically this should be the same as x
    print(x == reconstructed_x, x, x-reconstructed_x, sep='\t\t')

Выход:

e1 == e2    x value             round-off error
======================================================================
False       0.9382823115235075      1.1102230246251565e-16
True        -0.5081217386356917     0.0
True        -0.04229436058156134    0.0
True        1.1121100294357302      0.0
False       0.4974618312372863      -5.551115123125783e-17
True        -0.20409933212316553    0.0
True        -0.6501652554924282     0.0
True        -3.048057937738731      0.0
True        1.6236075700470816      0.0
True        0.4936926651641918      0.0
...