Обсуждаемая функция выглядит следующим образом:
def softmax(x):
assert len(x.shape) == 2
m = x.max(1)[:, np.newaxis]
u = np.exp(x - m)
z = u.sum(1)[:, np.newaxis]
return u / z
По мере того, как оператор assert
предлагает, функция softmax должна применяться к массиву 2D ; и чтобы все строки результата u / z
суммировались в одну. Вот почему применяются методы max
и sum
по строкам , , т.е. . с параметром axis=1
.
Broadcasting и np.newaxis
Для каждой строки x[i]
из x
мы хотим вычислить np.exp(x[i]) / np.sum(np.exp(x[i]))
. Здесь термин нормализации np.sum(np.exp(x[i]))
является числом, а термин np.exp(x[i])
является одномерным массивом. Благодаря numpy правилам вещания , операция может быть выполнена.
Теперь итераций в строках x
можно избежать благодаря numpy. Давайте возьмем в качестве примера следующий массив для np.exp(x)
.
u = np.array([[ 9, 6, 13, 19, 8],
[ 2, 17, 18, 0, 13],
[ 8, 3, 2, 18, 10]]) # np.exp(x)
u.sum(axis=1) # normalization term: array([55, 50, 41])
Цель состоит в том, чтобы разделить каждую строку u
на соответствующее значение нормализующего члена u.sum(axis=1)
. Однако правила вещания не допускают непосредственного разделения двух терминов, поскольку u
имеет форму (3, 5)
, а массив нормализации имеет форму (3,)
. Как документация numpy указывает:
Два измерения совместимы, когда
- они равны или
- один из них 1
Так что u
можно разделить на массивы (3, 5)
, (1, 5)
, (3, 1)
, (5,)
или ()
, но не на u.sum(1)
формы (3,)
.
Именно поэтому оператор индекса newaxis
используется для вставки новой оси в термин нормализации, делая ее двумерной с формой (3, 1)
.
u.sum(axis=1)[:, np.newaxis] # array([[55], [50], [41]])
Наконец, функция softmax для строк будет иметь вид
def softmax(x):
assert x.dim == 2
u = np.exp(x)
z = u.sum(axis=1)[:, np.newaxis]
return u / z
Numeri c стабильность
Однако применение этой функции к большим значениям может быть численно нестабильным, поскольку np.exp(x)
может быть очень большой. Обратите внимание, что вычитание добавления любой константы не изменит результат благодаря условию нормализации.
Вот почему максимумы каждой строки m
равны вычитается, так что все значения ниже нуля перед применением показательной функции.