Как напечатать «фактический» курс обучения в Adadelta в pytorch - PullRequest
0 голосов
/ 21 ноября 2018

Короче говоря :

Я не могу нарисовать кривую lr / epoch при использовании оптимизатора adadelta в pytorch, потому что optimizer.param_groups[0]['lr'] всегда возвращает одно и то же значение.

Подробно :

Adadelta может динамически адаптироваться во времени, используя только информацию первого порядка, и имеет минимальные вычислительные издержки за пределами ванильного стохастического градиентного спуска [1].

В pytorch, источниккод Adadelta здесь https://pytorch.org/docs/stable/_modules/torch/optim/adadelta.html#Adadelta

Поскольку он не требует ручной настройки скорости обучения, насколько мне известно, нам не нужно устанавливать расписание после объявления оптимизатора

self.optimizer = torch.optim.Adadelta(self.model.parameters(), lr=1)

Способ проверки скорости обучения:

current_lr = self.optimizer.param_groups[0]['lr']

Проблема в том, что он всегда возвращает 1 (начальный lr).

Может кто-нибудь сказать мне, как я могу получить истинную скорость обучения, чтобы я мог нарисовать кривую lr / epch?

[1] https://arxiv.org/pdf/1212.5701.pdf

1 Ответ

0 голосов
/ 21 ноября 2018

Проверка: self.optimizer.state.Это оптимизируется с помощью lr и используется в процессе оптимизации.

Из документации lr - это просто:

lr (число с плавающей запятой, необязательно): коэффициент, который масштабирует дельту до ее применения к параметрам (по умолчанию: 1,0)

https://pytorch.org/docs/stable/_modules/torch/optim/adadelta.html

Отредактировано: вы можете найти значения acc_delta в значениях self.optimizer.state, но вам нужно просмотреть словари, содержащиеся в этом словаре:

dict_with_acc_delta = [self.optimizer.state[i] for i in self.optimizer.state.keys() if "acc_delta" in self.optimizer.state[i].keys()]
acc_deltas = [i["acc_delta"] for i in dict_with_acc_delta]

У меня естьвосемь слоев и форм элементов в списке acc_deltas следующие

[torch.Size([25088]),
 torch.Size([25088]),
 torch.Size([4096, 25088]),
 torch.Size([4096]),
 torch.Size([1024, 4096]),
 torch.Size([1024]),
 torch.Size([102, 1024]),
 torch.Size([102])]
...