Как мне объединить 2D-свертки в PyTorch? - PullRequest
2 голосов
/ 12 октября 2019

Из линейной алгебры мы знаем, что линейные операторы являются коммутативными и ассоциативными.

В мире глубокого обучения эта концепция используется для оправдания введения нелинейностей между NN-слоями, явление, которое в разговорной речи называется линейная лазанья , ( ссылка ).

При обработке сигналов это хорошо известный прием для оптимизации требований к памяти и / или времени выполнения ( ссылка ).

Таким образом, объединение сверток является очень полезным инструментом с разных точек зрения. Как реализовать это с PyTorch?

1 Ответ

5 голосов
/ 12 октября 2019

Если у нас есть y = x * a * b (где * означает свертку, а a, b - ваши ядра), мы можем определить c = a * b так, чтобы y = x * c = x * a * b было следующим образом:

import torch

def merge_conv_kernels(k1, k2):
    """
    :input k1: A tensor of shape ``(out1, in1, s1, s1)``
    :input k1: A tensor of shape ``(out2, in2, s2, s2)``
    :returns: A tensor of shape ``(out2, in1, s1+s2-1, s1+s2-1)``
      so that convolving with it equals convolving with k1 and
      then with k2.
    """
    padding = k2.shape[-1] - 1
    # Flip because this is actually correlation, and permute to adapt to BHCW
    k3 = torch.conv2d(k1.permute(1, 0, 2, 3), k2.flip(-1, -2),
                      padding=padding).permute(1, 0, 2, 3)
    return k3

Для иллюстрацииэквивалентность, этот пример объединяет два ядра с параметрами 900 и 5000 соответственно в эквивалентное ядро ​​из 28 параметров:

# Create 2 conv. kernels
out1, in1, s1 = (100, 1, 3)
out2, in2, s2 = (2, 100, 5)
kernel1 = torch.rand(out1, in1, s1, s1, dtype=torch.float64)
kernel2 = torch.rand(out2, in2, s2, s2, dtype=torch.float64)

# propagate a random tensor through them. Note that padding
# corresponds to the "full" mathematical operation (s-1)
b, c, h, w = 1, 1, 6, 6
x = torch.rand(b, c, h, w, dtype=torch.float64) * 10
c1 = torch.conv2d(x, kernel1, padding=s1 - 1)
c2 = torch.conv2d(c1, kernel2, padding=s2 - 1)

# check that the collapsed conv2d is same as c2:
kernel3 = merge_conv_kernels(kernel1, kernel2)
c3 = torch.conv2d(x, kernel3, padding=kernel3.shape[-1] - 1)
print(kernel3.shape)
print((c2 - c3).abs().sum() < 1e-5)

Примечание: Эквивалентность предполагает, чтоу нас есть неограниченное числовое разрешение. Я думаю, что было исследование по суммированию многих линейных операций с низким разрешением и показало, что сети извлекли выгоду из числовой ошибки, но я не могу найти ее. Любая ссылка приветствуется!

...