Вы можете проверить Реализацию PyTorch в SGD , чтобы получить некоторые советы и основы этого кода.
Есть несколько вещей, которые должны ускорить вашу пользовательскую регуляризацию. Ниже приведена чистая версия (немного псевдокода, см. Оригинал) интересующих нас частей:
for p in group['params']:
if p.grad is None:
continue
d_p = p.grad.data
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
p.data.add_(-group['lr'], d_p)
return loss
Кстати. Кажется, ваша реализация математически обоснована (поправьте меня, если я что-то пропустила) и эквивалентна PyTorch, но будет медленной.
Изменение только градиента
Обратите внимание, что вы выполняете регуляризацию явно во время прямого прохода. Это занимает много времени, более или менее, потому что:
- принимает параметры и перебирает их
- принимает их в степени
2
- суммирует все из них
- добавить в переменную, содержащую все предыдущие параметры (все это при динамическом создании графа и создании новых узлов).
Что делает pytorch
, это фокусируется только на обратном проходе как это все, что нужно. Это очень удобно, потому что:
- параметры должны быть загружены и повторены один раз в любом случае во время исправлений, выполняемых оптимизатором (в вашем случае они извлекаются дважды)
- без мощности
2
, поскольку градиент w**2
равен просто 2*w
(2
далее опущен, а L2
часто выражается как 1/2 * w **2
, чтобы сделать его проще и немного быстрее) - без накопления и создание дополнительных узлов графа
По сути, эта строка:
d_p.add_(weight_decay, p.data)
Изменяет градиент, добавляя p.data
(вес), умноженный на weight_decay
все выполненное на месте (уведомление d_p.add_
), и это все, что вам нужно сделать, чтобы выполнить L2
регуляризацию.
Наконец, эта строка:
p.data.add_(-group['lr'], d_p)
Обновляет веса с градиентом (изменяемым с уменьшением веса), используя стандартные Формула SGD (еще раз, на месте, чтобы быть максимально быстрой, по крайней мере, на уровне Python).
Ваша собственная реализация
Я бы посоветовал вам следовать аналогичной логике c для вашей регуляризации я Если вы хотите сделать это быстрее.
Вы можете скопировать PyTorch
реализацию SGD
и изменить только одну соответствующую строку. Это также даст вам функциональность оптимизатора PyTorch на тот случай, если он понадобится вам в ваших экспериментах.
Для L1
регуляризации (|w|
вместо w**2
) вам придется рассчитать его производную (что 1
для положительного случая, -1
для отрицательного и неопределенного для 0
(мы не можем этого иметь, поэтому должно быть ноль)).
Имея это в виду, мы можем написать weight_decay
вот так:
if weight_decay != 0:
d_p.add_(weight_decay, torch.sign(p.data))
torch.sign
возвращает 1
для положительных значений и -1
для отрицательных и 0
для ... да, 0
.
Надеюсь это помогает, точная реализация остается для вас (напишите мне в комментариях, если у вас есть какие-либо вопросы или проблемы).