Какова цель параметра dim в torch.nn.Softmax - PullRequest
1 голос
/ 21 апреля 2020

Я не понимаю, к чему относится параметр dim в torch.nn.Softmax. Есть предупреждение, которое говорит мне, чтобы использовать его, и я установил его на 1, но я не понимаю, что я устанавливаю. Где это используется в формуле:

Softmax(xi​)=exp(xi)/∑j​exp(xj​)​

Здесь нет тусклости, так к чему это относится?

1 Ответ

2 голосов
/ 21 апреля 2020

Документация Pytorch на torch.nn.Softmax гласит: dim (int) - Измерение, по которому будет вычисляться Softmax (поэтому каждый срез вдоль dim будет суммироваться до 1).

Например, если у вас есть матрица с двумя измерениями, вы можете выбрать, хотите ли вы применить softmax к строкам или столбцам:

import torch 
import numpy as np

softmax0 = torch.nn.Softmax(dim=0) # Applies along columns
softmax1 = torch.nn.Softmax(dim=1) # Applies along rows 

v = np.array([[1,2,3],
              [4,5,6]])
v =  torch.from_numpy(v).float()

softmax0(v)
# Returns
#[[0.0474, 0.0474, 0.0474],
# [0.9526, 0.9526, 0.9526]])


softmax1(v)
# Returns
#[[0.0900, 0.2447, 0.6652],
# [0.0900, 0.2447, 0.6652]]

Обратите внимание, как для softmax0 столбцы добавляют к 1, и для softmax1 строки добавляют к 1.

...