Pytorch - Размер стека должен быть точно таким же? - PullRequest
0 голосов
/ 17 мая 2018

В pytorch, учитывая тензоры a формы (1X11) и b формы (1X11), torch.stack((a,b),0) даст мне тензор формы (2X11)

Однако, когда a имеет форму (2X11), а b имеет форму (1X11), torch.stack((a,b),0) вызовет ошибку, ср. «Два тензорных размера должны быть точно такими же».

Поскольку эти два тензора являются выходными данными модели ( градиент включен ), я не могу преобразовать их в numpy для использования np.stack() или np.vstack().

Есть ли какое-нибудь возможное решение для минимального использования памяти GPU?

1 Ответ

0 голосов
/ 17 мая 2018

Кажется, вы хотите использовать torch.cat() (конкатенировать тензоры по существующему измерению), а не torch.stack() (конкатенировать / складывать тензоры по новому измерению):

import torch

a = torch.randn(1, 42, 1, 1)
b = torch.randn(1, 42, 1, 1)

ab = torch.stack((a, b), 0)
print(ab.shape)
# torch.Size([2, 1, 42, 1, 1])

ab = torch.cat((a, b), 0)
print(ab.shape)
# torch.Size([2, 42, 1, 1])
aab = torch.cat((a, ab), 0)
print(aab.shape)
# torch.Size([3, 42, 1, 1])
...