Как эффективно нормализовать группу элементов в тензоре - PullRequest
0 голосов
/ 23 октября 2018

Контекст:

Я пытаюсь воспроизвести «Матричные капсулы Хинтона с EM-маршрутизацией» (https://openreview.net/forum?id=HJWLfGWRb).

. В какой-то момент выполняются сверточные операции (в том смысле, чтовыходной тензор соединен с входным тензором, и на каждый элемент выходного тензора влияют только входные элементы, содержащиеся в 2D-маске размера K).

Ввод Тензор x формы w_in,w_in, где

  • w_in=14

Промежуточный тензор, отображенный на вход Тензор x_mapped формы w_out,w_out,K,K, где

  • K=3 - размер ядра свертки
  • w_out=6, полученный в результате свертки с stride=2

Суммирование по измерениям 2 и 3(оба размера K) означает суммирование на входных элементах, подключенных к выходному элементу, местоположение которого задается размерами 0 и 1.

Вопрос:

Как эффективно нормализовать (чтобы1) группы элементов в x_mapped, основанные на их расположении втензор ввода x?

Например:
x_mapped(0,0,2,2)
x_mapped(1,0,0,2)
x_mapped(0,1,2,0)
x_mapped(1,1,0,0)
все подключены к x(2,2)(формула i_out*stride + K_index = i_in).По этой причине я хотел бы, чтобы сумма этих 4 элементов была равна 1.

И я хотел бы сделать это для всех групп элементов в x_mapped, которые «связаны» с одним и тем же элементом вx.

Я могу понять, как это сделать:

  1. Создание словаря с местом ввода в качестве ключа и списка элементов вывода в качестве значения
  2. Цикл по словарю, суммирование элементов в списке для заданного местоположения ввода и деление их на эту сумму

, но это кажется мне действительно неэффективным.

1 Ответ

0 голосов
/ 24 октября 2018

Я решил это следующим образом:

  1. Создание словаря с 2-кортежем в качестве ключа (координаты в x) и списком элементов x_mapped в качестве значений.
  2. Один цикл над словарем, сжатие всех элементов одного словарного элемента, затем нормализация.

Вот код:

from collections import defaultdict
import torch

ho = 6
wo = 6
stride = 2
K = 3

d = defaultdict(list)

x_mapped = torch.arange(0,ho*wo*K*K).view(ho,wo,K,K).type(dtype = torch.DoubleTensor)

for i_out in range(0,ho):
    for j_out in range(0,wo):
        for K_i in range(0,K):
            for K_j in range(0, K):
                i_in = i_out * stride + K_i
                j_in = j_out * stride + K_j

                d[(i_in, j_in)].append((i_out, j_out, K_i, K_j))

for _ , value in d.items():
    ho_list, wo_list, K_i_list, K_j_list = zip(*value)
    x_mapped[ho_list, wo_list, K_i_list, K_j_list] = x_mapped[ho_list, wo_list, K_i_list, K_j_list] / torch.sum(
        x_mapped[ho_list, wo_list, K_i_list, K_j_list])
...