Как выполняется Пакетное обучение в Pytorch? - PullRequest
1 голос
/ 19 июня 2019

Когда вы посмотрите на то, как строится сетевая архитектура внутри кода Pytorch, нам нужно расширить torch.nn.Module и внутри __init__, мы определяем модуль сетей, и Pytorch собирается отслеживать градиенты параметров этих модулей.,Затем внутри функции forward мы определяем, как должен выполняться прямой проход для нашей сети.

Здесь я не понимаю, как происходит пакетное обучение.Ни в одном из приведенных выше определений, включая функцию forward, нас не волнует размер пакета входных данных для нашей сети.Единственное, что нам нужно настроить для выполнения пакетного обучения, - это добавить к входу дополнительное измерение, соответствующее размеру пакета, но ничего внутри определения сети не изменится, если мы будем работать с пакетным обучением.По крайней мере, это то, что я видел в кодах здесь .

Итак, если все, что я объяснил до сих пор, верно (я был бы очень признателен, если бы вы дали мне знатьесли я что-то неправильно понял), как выполняется пакетное обучение, если ничего не объявлено относительно размера пакета внутри определения нашего сетевого класса (класса, который наследует torch.nn.Module)?В частности, мне интересно узнать, как алгоритм пакетного градиентного спуска реализован в pytorch, когда мы просто установили nn.MSELoss с размерностью пакета.

1 Ответ

1 голос
/ 23 июня 2019

Проверьте это:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()         

    def forward(self, x):
        print("Hi ma")        
        print(x)
        x = F.relu(x)
        return x

n = Net()
r = n(torch.tensor(-1))
print(r)
r = n.forward(torch.tensor(1)) #not planned to call directly
print(r)

out:

Hi ma
tensor(-1)
tensor(0)
Hi ma
tensor(1)
tensor(1)

Следует помнить, что forward не следует вызывать напрямую.PyTorch сделал этот объект модуля n вызываемым.Они реализовали callable как:

 def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        hook(self, input)
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            raise RuntimeError(
                "forward hooks should never return any values, but '{}'"
                "didn't return None".format(hook))
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result

И просто n() автоматически вызовет forward.

В общем, __init__ определяет структуру модуля и forward() определяет операции для одного пакета.

Эта операция может повторяться при необходимости для некоторых элементов структуры, или вы можете вызывать функции для тензоров напрямую, как мы это делали x = F.relu(x).

Вы получили это великолепновсе в PyTorch будет выполняться партиями (мини-пакетами), поскольку PyTorch оптимизирован для такой работы.

Это означает, что при чтении изображения вы будете читать не один, а один bs Пакеты изображений.

...