С этим кодом есть две проблемы.
- Первый уровень реализации заключается в том, что вы используете операцию на месте, которая, как правило, плохо работает с autograd.Вместо
relu1[relu1 > self.act_max] = self.act_max
вы должны использовать операцию «без места», например
relu1 = torch.where(relu1 > self.act_max, self.act_max, relu1)
Другой является более общим: нейронные сети обычно обучаются с использованием методов градиентного спуска, и пороговые значения могут не иметь градиента - функция потерь не дифференцируема по отношению к пороговым значениям.
В вашей модели вывы используете грязный хакерранд (пишете ли вы как есть или используете torch.where
) - model.act_max.grad.data
определяется только потому, что для некоторых элементов их значение установлено на model.act_max
.Но этот градиент ничего не знает о , почему им присвоено это значение.Чтобы сделать вещи более конкретными, давайте определим операцию отсечки C(x, t)
, которая определяет, будет ли x
выше или ниже порога t
C(x, t) = 1 if x < t else 0
и запишем вашу операцию отсечения как продукт
clip(x, t) = C(x, t) * x + (1 - C(x, t)) * t
затем вы можете видеть, что порог t
имеет двоякое значение: он контролирует, когда необходимо срезать (внутри C
), и контролирует значение выше среза (трейлинг).t
).Поэтому мы можем обобщить операцию как
clip(x, t1, t2) = C(x, t1) * x + (1 - C(x, t1)) * t2
Проблема с вашей операцией состоит в том, что она дифференцируема только по t2
, но не t1
.Ваше решение связывает их вместе, так что t1 == t2
, но все равно это тот случай, когда градиентный спуск будет действовать так, как если бы не было изменения порога, а только изменение значения выше порога.
ДляПо этой причине в общем случае ваша операция порогового значения, возможно, не учитывает ценность, на которую вы надеетесь.Об этом следует помнить при разработке своих операций, но не гарантию сбоя - фактически, если вы рассмотрите стандарт ReLU
для смещенного вывода некоторой линейной единицы, мы получим похожую картину.Мы определяем операцию отсечения H
H(x, t) = 1 if x > t else 0
и ReLU
как
ReLU(x + b, t) = (x + b) * H(x + b, t) = (x + b) * H(x, t - b)
, где мы можем снова обобщить до
ReLU(x, b, t) = (x + b) * H(x, t)
и снова мы можем только узнать b
и t
неявно следует b
.Все же, похоже, работает :) 1074 *