Извлечение индексов из максимального пула по единым данным - PullRequest
1 голос
/ 30 марта 2020

Я пытаюсь найти максимальное количество точек в 2D-тензоре для данного размера ядра, но у меня возникают проблемы с особым случаем, когда все значения одинаковы. Например, в следующем примере я хотел бы отметить каждую точку как максимальную точку:

+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+
| 5 | 5 | 5 | 5 |
+---+---+---+---+

Если я запускаю torch.nn.functional.max_pool2d с размером ядра = 3, шаг = 1 и заполнение = 1, я получаю следующие признаки:

+---+---+---+----+
| 0 | 0 | 1 |  2 |
+---+---+---+----+
| 0 | 0 | 1 |  2 |
+---+---+---+----+
| 4 | 4 | 5 |  6 |
+---+---+---+----+
| 8 | 8 | 9 | 10 |
+---+---+---+----+

Какие изменения мне нужно учитывать, чтобы вместо этого получить следующие признаки?

+----+----+----+----+
| 1  | 2  | 3  |  4 |
+----+----+----+----+
|  5 |  6 |  7 |  8 |
+----+----+----+----+
|  9 | 10 | 11 | 12 |
+----+----+----+----+
| 13 | 14 | 15 | 16 |
+----+----+----+----+

Ответы [ 2 ]

3 голосов
/ 30 марта 2020

Вы можете сделать следующее:

a = torch.ones(4,4)
indices = (a == torch.max(a).item()).nonzero()

То, что это делает, возвращает тензор размером [16,2] с 2D-координатами максимального значения (ий), то есть [0,0], [0,1], .., [3,3]. Часть torch.max должна быть легкой для понимания, nonzero() считает логический тензор, заданный (a == torch.max(a).item()), принимает False равным 0 и возвращает ненулевые индексы. Надеюсь, это поможет!

0 голосов
/ 30 марта 2020

Если вы хотите, чтобы индексы в 2d форме @ccl дали вам ответ, но для 1d индексов вы можете сначала сделать x 1d, используя тензор torch.flatten, а затем получить индексы с torch.nonzero и, наконец, преобразовать в ту же форму.

x = torch.ones(4,4) * 5

(x.flatten() == x.flatten().max()).nonzero().reshape(x.shape) + 1
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...