Алгоритм реализации Conv2d в PyTorch - PullRequest
1 голос
/ 24 января 2020

Я работаю над моделью логического вывода модели pytorch onnx, поэтому задаю этот вопрос.

Предположим, у меня есть изображение с размерами 32 x 32 x 3 (набор данных CIFAR-10). Я передаю его через Conv2d с размерами: 3 x 192 x 5 x 5. Я использовал команду: Conv2d(3, 192, kernel_size=5, stride=1, padding=2)

Используя формулу (указанную здесь для справки pg12 https://arxiv.org/pdf/1603.07285.pdf), я должен получить выходное изображение с размерами 28 x 28 x 192 (input - kernel + 1 = 32 - 5 + 1 ).

Вопрос в том, как PyTorch реализовал этот 4-мерный тензор 3 x 192 x 5 x 5, чтобы получить вывод 28 x 28 x 192? Слой представляет собой 4-мерный тензор, а входное изображение - 2-мерное.

Как ядро ​​(5x5) распространяется в матрице изображения 32 x 32 x 3? Что ядро ​​сворачивает с первым -> 3 x 192 или 32 x 32?

Примечание : я понял 2d аспекты вещей. Я задаю вышеуказанные вопросы в 3 или более раз.

1 Ответ

1 голос
/ 24 января 2020

Входные данные для Conv2d представляют собой тензор формы (N, C_in, H_in, W_in), а выходные данные имеют форму (N, C_out, H_out, W_out), где N - размер пакета (количество изображений), C - это количество каналов, H - высота, W - ширина. Высота и ширина вывода H_out, W_out вычисляются следующим образом (без учета расширения):

H_out = (H_in + 2*padding[0] - kernel_size[0]) / stride[0] + 1
W_out = (W_in + 2*padding[1] - kernel_size[1]) / stride[1] + 1

См. cs231n для объяснения того, как были получены эти формулы.

В вашем примере N=1, H_in = 32, W_in = 32, C_in = 3, kernel_size = (5, 5), strides = (1, 1), padding = (0, 0), давая H_out = 28, W_out = 28.

C_out=192 означает, что имеется 192 различных фильтра, каждый из которых имеет форму (C_in, kernel_size[0], kernel_size[1]) = (3, 5, 5). Каждый фильтр независимо выполняет свертку с входным изображением, в результате чего получается 2D-тензор формы (H_out, W_out) = (28, 28), а поскольку имеется C_out = 192 фильтров и N = 1 изображений, конечный результат имеет форму (N, C_out, H_out, W_out) = (1, 192, 28, 28).

Чтобы понять, как именно выполняется свертка, см. Демонстрационную версию свертки .

enter image description here

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...