PyTorch torch.max в нескольких измерениях - PullRequest
5 голосов
/ 11 апреля 2020

Имеют тензор типа: x.shape = [3, 2, 2].

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

Мне нужно взять .max() по 2-му и 3-му измерениям. Я ожидаю что-то подобное [-0.2632, -0.1453, -0.0274] в качестве вывода. Я пытался использовать: x.max(dim=(1,2)), но это вызывает ошибку.

Ответы [ 2 ]

5 голосов
/ 11 апреля 2020

На сегодняшний день в PyTorch нет способа сделать .min() или .max() для нескольких измерений. Об этом есть открытый выпуск , за которым вы можете следить и смотреть, будет ли он когда-либо реализован. Обходной путь в вашем случае:

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

print(x.view(x.size(0), -1).max(dim=-1))

# output:
# >>> values=tensor([-0.2632, -0.1453, -0.0274]),
# >>> indices=tensor([3, 3, 3]))

Итак, если вам нужны только значения: x.view(x.size(0), -1).max(dim=-1).values.

Если x не является непрерывным тензором, то .view() не удастся. В этом случае вам следует использовать .reshape().

3 голосов
/ 11 апреля 2020

Хотя решение Берриэля решает этот конкретный вопрос c, я подумал, что добавление некоторого объяснения может помочь всем пролить свет на трюк, который здесь используется, так что его можно адаптировать для (m) любые другие измерения.

Давайте начнем с проверки формы входного тензора x:

In [58]: x.shape   
Out[58]: torch.Size([3, 2, 2])

Итак, у нас есть 3D-тензор формы (3, 2, 2). Теперь, в соответствии с вопросом OP, нам нужно вычислить maximum значений в тензоре как по 1 st , так и по 2 nd измерениям. На момент написания статьи аргумент torch.max() dim поддерживает только int. Итак, мы не можем использовать кортеж. Итак, мы будем использовать следующий трюк, который я назову,

Трюк Flatten & Max : так как мы хотим вычислить max по обоим 1 st и размеры 2 и , мы сведем оба эти измерения в одно измерение и оставим измерение 0 th без изменений. Это именно то, что происходит, делая:

In [61]: x.flatten().reshape(x.shape[0], -1).shape   
Out[61]: torch.Size([3, 4])   # 2*2 = 4

Итак, теперь мы сжали 3D-тензор до 2D-тензора (то есть матрицы).

In [62]: x.flatten().reshape(x.shape[0], -1) 
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
        [-0.1821, -0.1747, -0.1526, -0.1453],
        [-0.0642, -0.0568, -0.0347, -0.0274]])

Теперь мы можем просто применить max к измерению 1 st (т.е. в этом случае первое измерение также является последним измерением), поскольку сглаженные измерения находятся в этом измерении.

In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1)    # or: `dim = -1`
Out[65]: 
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))

Мы получили 3 значения в результирующем тензоре, поскольку у нас было 3 строки в матрице.


Теперь, с другой стороны, если вы хотите вычислить max по 0 th и 1 st размеры, вы бы сделали:

In [80]: x.flatten().reshape(-1, x.shape[-1]).shape 
Out[80]: torch.Size([6, 2])    # 3*2 = 6

In [79]: x.flatten().reshape(-1, x.shape[-1]) 
Out[79]: 
tensor([[-0.3000, -0.2926],
        [-0.2705, -0.2632],
        [-0.1821, -0.1747],
        [-0.1526, -0.1453],
        [-0.0642, -0.0568],
        [-0.0347, -0.0274]])

Теперь мы можем просто применить max к измерению 0 th , так как это результат нашего сглаживания , ((также из нашей первоначальной формы (3, 2, 2), после получения максимума по первым двум измерениям, мы должны получить два значения как результат.)

In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0) 
Out[82]: 
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))

Аналогичным образом вы можете адаптировать это подход к нескольким измерениям и другим функциям сокращения, таким как min.


Примечание : я следую терминологии измерений на основе 0 (0, 1, 2, 3, ...) просто для того, чтобы в соответствии с использованием PyTorch и кода.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...