Самый простой способ уменьшить количество каналов - использовать ядро 1x1:
import torch
x = torch.rand(1, 512, 50, 50)
conv = torch.nn.Conv2d(512, 3, 1)
y = conv(x)
print(y.size())
# torch.Size([1, 3, 50, 50])
Если вам действительно по какой-то причине необходимо выполнить объединение в пул по измерению каналов, возможно, вы захотите переставить измерениятак что размер каналов поменяется местами с другим измерением (например, шириной).На эту идею ссылались здесь .