Another approach that would solve your problem:
import torch

orig_shape = (100, 1024, 14, 14)
new_shape = (100, 10, 1024, 14, 14)
input = torch.randn(orig_shape)             # [100, 1024, 14, 14]
input = input.unsqueeze(1)                  # [100, 1, 1024, 14, 14]
input = input.expand(*new_shape)            # [100, 10, 1024, 14, 14] (a view, no copy yet)
input = input.transpose(0, 1).contiguous()  # [10, 100, 1024, 14, 14] (the data is copied here)
input = input.view(-1, *orig_shape[1:])     # [1000, 1024, 14, 14]
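To make it clearer why the `contiguous()` call is needed, here is a minimal sketch that wraps the same steps in a helper and inspects the intermediate tensor. The name `repeat_batch` and the variables `x`, `n`, and `expanded` are illustrative assumptions and are not part of the original snippet.

```python
import torch


def repeat_batch(x: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical helper: stack n copies of x along dim 0 using the steps above."""
    expanded = x.unsqueeze(1).expand(x.shape[0], n, *x.shape[1:])
    # transpose() only swaps strides; contiguous() materializes the copies so
    # that view() can merge the first two dimensions.
    return expanded.transpose(0, 1).contiguous().view(-1, *x.shape[1:])


x = torch.randn(2, 3, 4)
expanded = x.unsqueeze(1).expand(2, 5, 3, 4)

# expand() returns a view: the repeated dimension has stride 0 and the
# tensor still points at the memory of x.
print(expanded.stride(1))                     # 0
print(expanded.data_ptr() == x.data_ptr())    # True

out = repeat_batch(x, 5)
print(out.shape)                              # torch.Size([10, 3, 4])
```

Since the expanded dimension has stride 0, no memory is duplicated until `contiguous()` runs, so the full-size copy is allocated only once.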
We can verify this on a small tensor.
import torch

orig_shape = (2, 3, 4)
new_shape = (2, 2, 3, 4)  # two copies, so the result has shape [4, 3, 4]
input = torch.randn(orig_shape)
print(input)
input = input.unsqueeze(1)
input = input.expand(*new_shape)
input = input.transpose(0, 1).contiguous()
input = input.view(-1, *orig_shape[1:])
print(input)
The snippet prints:
tensor([[[-1.1728,  1.0421, -1.0716,  0.6456],
         [-1.2214,  1.1484, -0.1436,  1.2353],
         [-0.4395, -0.9473, -0.1382, -0.9357]],

        [[-0.4735, -1.4329, -0.0025, -0.6384],
         [ 0.5102,  0.7813,  1.2810, -0.6013],
         [ 0.6152,  1.1734, -0.4591, -1.7447]]])
tensor([[[-1.1728,  1.0421, -1.0716,  0.6456],
         [-1.2214,  1.1484, -0.1436,  1.2353],
         [-0.4395, -0.9473, -0.1382, -0.9357]],

        [[-0.4735, -1.4329, -0.0025, -0.6384],
         [ 0.5102,  0.7813,  1.2810, -0.6013],
         [ 0.6152,  1.1734, -0.4591, -1.7447]],

        [[-1.1728,  1.0421, -1.0716,  0.6456],
         [-1.2214,  1.1484, -0.1436,  1.2353],
         [-0.4395, -0.9473, -0.1382, -0.9357]],

        [[-0.4735, -1.4329, -0.0025, -0.6384],
         [ 0.5102,  0.7813,  1.2810, -0.6013],
         [ 0.6152,  1.1734, -0.4591, -1.7447]]])
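Instead of comparing the printed values by eye, the same check can be done with assertions. The sketch below is only an illustration: the variable names `x`, `n`, and `out` are assumptions, and the comparison against `Tensor.repeat` is a cross-check I added, not part of the approach above.

```python
import torch

x = torch.randn(2, 3, 4)
n = 2  # number of copies, matching the printed output above

out = (
    x.unsqueeze(1)                     # [2, 1, 3, 4]
    .expand(x.shape[0], n, *x.shape[1:])  # [2, n, 3, 4]
    .transpose(0, 1)                   # [n, 2, 3, 4]
    .contiguous()
    .view(-1, *x.shape[1:])            # [n * 2, 3, 4]
)

# Every consecutive chunk of x.shape[0] entries should be an exact copy of x.
for chunk in out.split(x.shape[0], dim=0):
    assert torch.equal(chunk, x)

# Tensor.repeat tiles the whole tensor along dim 0 and yields the same layout.
assert torch.equal(out, x.repeat(n, 1, 1))
print("ok")
```

The last assertion also shows that, for this layout, `x.repeat(n, 1, 1)` is a shorter way to get an identical tensor.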