Pytorch: Память Эффективная взвешенная сумма с весами, распределенными по каналам - PullRequest
0 голосов
/ 04 мая 2020

Входы:

1) I = Тензор дим (N, C, X) (Вход)

2) W = Тензор дим (N, X, Y) (Вес)

Выход:

1) O = Тензор дим (N, C, Y) (Выход)

Я хочу вычислить:

I = I.view(N, C, X, 1)
W = W.view(N, 1, X, Y)
PROD = I*W
O = PROD.sum(dim=2)
return O

без наложения N * C * X * Y накладные расходы памяти .

В основном я хочу вычислить взвешенную сумму карты объектов, где веса одинаковы по измерению канала, без дополнительных затрат памяти на канал.


Возможно, я мог бы использовать

from itertools import product

O = torch.zeros(N, C, Y)
for n, x, y in product(range(N), range(X), range(Y)):
    O[n, :, y] += I[n, :, x]*W[n, x, y]
return O

, но это было бы медленнее (без вещания), и я не уверен, как из-за сохранения переменных для обратного прохода может возникнуть много дополнительной памяти.

1 Ответ

2 голосов
/ 04 мая 2020

Вы можете использовать torch.bmm (https://pytorch.org/docs/stable/torch.html#torch .bmm ). Просто выполните torch.bmm(I,W)

Чтобы проверить результаты:

import torch
N, C, X, Y= 100, 10, 9, 8 

i = torch.rand(N,C,X)
w = torch.rand(N,X,Y)
o = torch.bmm(i,w)

# desired result code
I = i.view(N, C, X, 1)
W = w.view(N, 1, X, Y)
PROD = I*W
O = PROD.sum(dim=2)

print(torch.allclose(O,o)) # should output True if outputs are same.

РЕДАКТИРОВАТЬ: В идеале, я бы предположил, что использование умножения внутренней матрицы Pytorch является эффективным. Однако вы также можете измерить использование памяти с помощью tracemalloc (по крайней мере, на процессоре). См. https://discuss.pytorch.org/t/measuring-peak-memory-usage-tracemalloc-for-pytorch/34067 для графического процессора.

import torch
import tracemalloc
tracemalloc.start()
N, C, X, Y= 100, 10, 9, 8 

i = torch.rand(N,C,X)
w = torch.rand(N,X,Y)
o = torch.bmm(i,w)
# output is a tuple indicating current memory and peak memory
print(tracemalloc.get_traced_memory())  

Вы можете сделать то же самое с другим кодом и увидеть, что реализация bmm действительно эффективна.

import torch
import tracemalloc
tracemalloc.start()
N, C, X, Y= 100, 10, 9, 8 

i = torch.rand(N,C,X)
w = torch.rand(N,X,Y)

I = i.view(N, C, X, 1)
W = w.view(N, 1, X, Y)
PROD = I*W
O = PROD.sum(dim=2)
# output is a tuple indicating current memory and peak memory
print(tracemalloc.get_traced_memory())  

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...