Как найти максимальный индекс для каждой строки в тензорном объекте? - PullRequest
0 голосов
/ 14 февраля 2019

Итак, я создаю модель pytorch и для прямого прохода применяю метод прямого прохода, чтобы получить тензор баллов, который содержит баллы прогноза для каждого класса.Форма этого тензора [100, 10].Теперь я хочу получить точность, сравнивая ее с y, который содержит фактические оценки.Этот тензор имеет форму [100].Для сравнения я буду использовать torch.mean(scores == y) и посчитаю, сколько их одинаково.

Проблема в том, что мне нужно преобразовать тензор оценок, чтобы каждая строка просто содержала индекс наибольшего значения в каждой строке.Например, если тензор выглядел следующим образом,

tensor(
    [[0.3232, -0.2321, 0.2332, -0.1231, 0.2435, 0.6728],

    [0.2323, -0.1231, -0.5321, -0.1452, 0.5435, 0.1722],

    [0.9823, -0.1321, -0.6433, 0.1231, 0.023, 0.0711]]
)

Тогда я бы хотел, чтобы он был преобразован так, чтобы он выглядел следующим образом.

tensor([5, 4, 0])

Как я мог это сделать?

1 Ответ

0 голосов
/ 14 февраля 2019

Используйте argmax с желаемым dim (он же ось)

a = tensor(
    [[0.3232, -0.2321, 0.2332, -0.1231, 0.2435, 0.6728],
    [0.2323, -0.1231, -0.5321, -0.1452, 0.5435, 0.1722],
    [0.9823, -0.1321, -0.6433, 0.1231, 0.023, 0.0711]]
)

a.argmax(1)
# tensor([ 5,  4,  0])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...