Кажется, вы хотите использовать torch.cat()
(конкатенировать тензоры по существующему измерению), а не torch.stack()
(конкатенировать / складывать тензоры по новому измерению):
import torch
a = torch.randn(1, 42, 1, 1)
b = torch.randn(1, 42, 1, 1)
ab = torch.stack((a, b), 0)
print(ab.shape)
# torch.Size([2, 1, 42, 1, 1])
ab = torch.cat((a, b), 0)
print(ab.shape)
# torch.Size([2, 42, 1, 1])
aab = torch.cat((a, ab), 0)
print(aab.shape)
# torch.Size([3, 42, 1, 1])