PyTorch set_grad_enabled (False) против no_grad (): - PullRequest
0 голосов
/ 23 ноября 2018

Предполагая, что autograd включен (как по умолчанию), есть ли разница (кроме отступа) между выполнением:

with torch.no_grad():
    <code>

и

torch.set_grad_enabled(False)
<code>
torch.set_grad_enabled(True)

Ответы [ 2 ]

0 голосов
/ 11 апреля 2019

Документация torch.autograd.enable_grad гласит:

Включает расчет градиента в контексте no_grad.Это не имеет никакого эффекта за пределами no_grad.

Учитывая эту формулировку, ожидается следующее:

torch.set_grad_enabled(False)
with torch.enable_grad:
    # Gradient tracking will NOT be enabled here.
torch.set_grad_enabled(True)

против:

with torch.no_grad():
    with torch.enable_grad:
        # Gradient tracking IS enabled here.

Но как сине-феникс показывает , это не дело.

Я поднял вопрос здесь .

0 голосов
/ 23 ноября 2018

На самом деле нет, нет разницы в способе, используемом в вопросе.Когда вы посмотрите на исходный код no_grad.Вы видите, что он на самом деле использует torch.set_grad_enabled для архивирования этого поведения:

class no_grad(object):
    r"""Context-manager that disabled gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.
    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    Also functions as a decorator.


    Example::

        >>> x = torch.tensor([1], requires_grad=True)
        >>> with torch.no_grad():
        ...   y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """

    def __init__(self):
        self.prev = torch.is_grad_enabled()

    def __enter__(self):
        torch._C.set_grad_enabled(False)

    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
        return False

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_no_grad(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return decorate_no_grad

Однако при использовании в with -статменте есть дополнительная функциональность torch.set_grad_enabled сверх torch.no_grad, которая позволяет вамуправление для включения или выключения вычисления градиента:

    >>> x = torch.tensor([1], requires_grad=True)
    >>> is_train = False
    >>> with torch.set_grad_enabled(is_train):
    ...   y = x * 2
    >>> y.requires_grad

https://pytorch.org/docs/stable/_modules/torch/autograd/grad_mode.html


Редактировать:

@ TomHale Относительно вашего комментария.Я только что сделал короткий тест с PyTorch 1.0, и оказалось, что градиент будет активен:

import torch
w = torch.rand(5, requires_grad=True)
print('Grad Before:', w.grad)
torch.set_grad_enabled(False)
with torch.enable_grad():
    scalar = w.sum()
    scalar.backward()
    # Gradient tracking will be enabled here.
torch.set_grad_enabled(True)

print('Grad After:', w.grad)

Вывод:

Grad Before: None
Grad After: tensor([1., 1., 1., 1., 1.])

Таким образом, градиенты будут вычисляться в этом параметре.

Другая настройка, которую вы разместили в своем ответе, также приводит к тому же результату:

import torch
w = torch.rand(5, requires_grad=True)
print('Grad Before:', w.grad)
with torch.no_grad():
    with torch.enable_grad():
        # Gradient tracking IS enabled here.
        scalar = w.sum()
        scalar.backward()

print('Grad After:', w.grad)

Вывод:

Grad Before: None
Grad After: tensor([1., 1., 1., 1., 1.])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...