Алгоритм уменьшения десятичных чисел в шаблоне 0,2 -> 0,1 -> 0,09 -> 0,08 -> ... -> 0,02 -> 0,01 -> 0,009 - PullRequest
0 голосов
/ 27 апреля 2018

Я пытаюсь автоматизировать процесс снижения скорости обучения нейронной сети. Я хотел бы написать функцию, которая вызывается, если потери нейронной сети не уменьшаются в течение n эпох.

Эта функция будет принимать текущую скорость обучения в качестве параметра, а затем будет уменьшать ее на 0,1, 0,01, 0,001 и т. Д., В зависимости от ее текущего значения (количества значащих цифр, которые она имеет в настоящее время). Это дало бы паттерн распада 0,2 -> 0,1 -> 0,09 -> 0,08 -> ... -> 0,02 -> 0,01 -> 0,009

В качестве ориентира у меня в настоящее время скорость обучения снижается, как показано ниже, начиная с 0,1:

def decayLearningRate(learningRate):
    return learningRate ** 2

Однако эти прыжки слишком велики. Какой элегантный способ добиться того, что я предложил?

Примечание. Скорость обучения всегда начинается с 0,1 или менее.

Ответы [ 2 ]

0 голосов
/ 27 апреля 2018

Я звоню x-y problem по этому вопросу. Посмотрите на линии для типичных моделей затухания скорости обучения; что ты делаешь не гладко. В каждой последовательности из 10 * N эпох (где N - ваш интервал нетерпения), вы начинаете с 10% -ого затухания, затем ускоряете затухание до 11%, 12,5%, ... 50%, после чего вы сбрасываете значение до 10%.

Скорее, просто выберите пропорцию, которая подходит для вашего приложения. Классически, различные приложения использовали что-нибудь от 10% до 3 (или даже 10):

return learning_rate * 0.90
return learning_rate / 3
return learning_rate /10

Как и во всех неуказанных приложениях, вам придется поэкспериментировать с вашим N и используемым вами коэффициентом, чтобы увидеть, что работает лучше для вас.

0 голосов
/ 27 апреля 2018

Я не совсем уверен, что это элегантное решение , но это способ решить задачу:

from decimal import Decimal

value = Decimal('0.3')
n = 15

for i in range(n):
    last_digit = value.as_tuple().digits[-1]

    if last_digit == 1:
        value -= value / 10
    else:
        value -= value / last_digit
    print(value)

Выход:

0.2
0.1
0.09
0.08
0.07
0.06
0.05
0.04
0.03
0.02
0.01
0.009
0.008
0.007
0.006
...