Вы должны подумать о том, что на самом деле делает оператор max
?То есть:
- Возвращает или, лучше сказать, распространяет максимум.
И это именно то, что здесь происходит - требуется два илибольше тензоров и распространяется вперед (только) максимум.
Часто полезно взглянуть на короткий пример:
t1 = torch.rand(10, requires_grad=True)
t2 = torch.rand(10, requires_grad=True)
s1 = torch.sum(t1)
s2 = torch.sum(t2)
print('sum t1:', s1, 'sum t2:', s2)
m = torch.max(s1, s2)
print('max:', m, 'requires_grad:', m.requires_grad)
m.backward()
print('t1 gradients:', t1.grad)
print('t2 gradients:', t2.grad)
Этот код создает два случайныхТензор суммирует их и помещает их через функцию max.Затем на результат вызывается backward()
.
Давайте рассмотрим два возможных результата:
Результат 1 - сумма t1
больше:
sum t1: tensor(5.6345) sum t2: tensor(4.3965)
max: tensor(5.6345) requires_grad: True
t1 gradients: tensor([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
t2 gradients: tensor([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Результат 2 - сумма t2
больше:
sum t1: tensor(3.3263) sum t2: tensor(4.0517)
max: tensor(4.0517) requires_grad: True
t1 gradients: tensor([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
t2 gradients: tensor([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
Как и следовало ожидать в случае s1
представляет максимальные градиенты будут рассчитаны для t1
.Аналогично, когда s2
- максимальные градиенты будут рассчитаны для t2
.
- Так же, как и на шаге вперед, обратное распространение распространяется назад через максимум.
Стоит отметить, чтодругие тензоры , которые не представляют максимум , все еще являются частью графика .Только градиенты устанавливаются на ноль.Если они не будут частью графика, вы получите None
в качестве градиента вместо нулевого вектора.
Вы можете проверить, что произойдет, если вы используете python- max
вместо torch.max
:
t1 = torch.rand(10, requires_grad=True)
t2 = torch.rand(10, requires_grad=True)
s1 = torch.sum(t1)
s2 = torch.sum(t2)
print('sum t1:', s1, 'sum t2:', s2)
m = max(s1, s2)
print('max:', m, 'requires_grad:', m.requires_grad)
m.backward()
print('t1 gradients:', t1.grad)
print('t2 gradients:', t2.grad)
Выход:
sum t1: tensor(4.7661) sum t2: tensor(4.4166)
max: tensor(4.7661) requires_grad: True
t1 gradients: tensor([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
t2 gradients: None