Хотя решение Берриэля решает этот конкретный вопрос c, я подумал, что добавление некоторого объяснения может помочь всем пролить свет на трюк, который здесь используется, так что его можно адаптировать для (m) любые другие измерения.
Давайте начнем с проверки формы входного тензора x
:
In [58]: x.shape
Out[58]: torch.Size([3, 2, 2])
Итак, у нас есть 3D-тензор формы (3, 2, 2)
. Теперь, в соответствии с вопросом OP, нам нужно вычислить maximum
значений в тензоре как по 1 st , так и по 2 nd измерениям. На момент написания статьи аргумент torch.max()
dim
поддерживает только int
. Итак, мы не можем использовать кортеж. Итак, мы будем использовать следующий трюк, который я назову,
Трюк Flatten & Max : так как мы хотим вычислить max
по обоим 1 st и размеры 2 и , мы сведем оба эти измерения в одно измерение и оставим измерение 0 th без изменений. Это именно то, что происходит, делая:
In [61]: x.flatten().reshape(x.shape[0], -1).shape
Out[61]: torch.Size([3, 4]) # 2*2 = 4
Итак, теперь мы сжали 3D-тензор до 2D-тензора (то есть матрицы).
In [62]: x.flatten().reshape(x.shape[0], -1)
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
[-0.1821, -0.1747, -0.1526, -0.1453],
[-0.0642, -0.0568, -0.0347, -0.0274]])
Теперь мы можем просто применить max
к измерению 1 st (т.е. в этом случае первое измерение также является последним измерением), поскольку сглаженные измерения находятся в этом измерении.
In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1) # or: `dim = -1`
Out[65]:
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))
Мы получили 3 значения в результирующем тензоре, поскольку у нас было 3 строки в матрице.
Теперь, с другой стороны, если вы хотите вычислить max
по 0 th и 1 st размеры, вы бы сделали:
In [80]: x.flatten().reshape(-1, x.shape[-1]).shape
Out[80]: torch.Size([6, 2]) # 3*2 = 6
In [79]: x.flatten().reshape(-1, x.shape[-1])
Out[79]:
tensor([[-0.3000, -0.2926],
[-0.2705, -0.2632],
[-0.1821, -0.1747],
[-0.1526, -0.1453],
[-0.0642, -0.0568],
[-0.0347, -0.0274]])
Теперь мы можем просто применить max
к измерению 0 th , так как это результат нашего сглаживания , ((также из нашей первоначальной формы (3, 2, 2
), после получения максимума по первым двум измерениям, мы должны получить два значения как результат.)
In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0)
Out[82]:
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))
Аналогичным образом вы можете адаптировать это подход к нескольким измерениям и другим функциям сокращения, таким как min
.
Примечание : я следую терминологии измерений на основе 0 (0, 1, 2, 3, ...
) просто для того, чтобы в соответствии с использованием PyTorch и кода.