Пользовательские градиенты для Flux вместо использования Zygote AD - PullRequest
2 голосов
/ 16 апреля 2020

У меня есть модель машинного обучения, где градиенты для параметров модели являются аналитическими c, и нет необходимости в автоматическом c дифференцировании. Тем не менее, я все еще хочу иметь возможность использовать преимущества различных оптимизаторов в Flux без необходимости полагаться на Zygote для дифференциации. Вот некоторые фрагменты моего кода.

W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = Flux.Params([b, c, U, W])

opt = ADAM(0.01)

Затем у меня есть функция, которая вычисляет аналитические c градиенты параметров моей модели, θ.

function gradients(x) # x = one input data point or a batch of input data points
    # stuff to calculate gradients of each parameter
    # returns gradients of each parameter

Я тогда хочу чтобы можно было сделать что-то вроде следующего.

grads = gradients(x)
update!(opt, θ, grads)

Мой вопрос: какую форму / тип должна возвращать моя gradient(x) функция для выполнения update!(opt, θ, grads), и как мне это сделать

1 Ответ

2 голосов
/ 16 апреля 2020

Если вы не используете Params, тогда grads просто должен быть градиентом. Единственное требование состоит в том, что θ и grads имеют одинаковый размер.

Например, map((x, g) -> update!(opt, x, g), θ, grads), где θ == [b, c, U, W] и grads = [gradients(b), gradients(c), gradients(U), gradients(W)] (не совсем уверен, что gradients ожидает в качестве входных данных для вас ).

ОБНОВЛЕНИЕ: Но чтобы ответить на ваш оригинальный вопрос, gradients необходимо вернуть Grads объект, найденный здесь: https://github.com/FluxML/Zygote.jl/blob/359e586766129878ca0e56121037ed80afda6289/src/compiler/interface.jl#L88

Так что-то вроде

# within gradient function body assuming gb is the gradient w.r.t b
g = Zygote.Grads(IdDict())
g.grads[θ[1]] = gb # assuming θ[1] == b

Но не использовать Params, вероятно, проще для отладки. Единственная проблема заключается в том, что update! не будет работать с массивом параметров, но вы можете легко определить свой собственный:

function Flux.Optimise.update!(opt, xs::Tuple, gs)
    for (x, g) in zip(xs, gs)
        update!(opt, x, g)
    end
end

# use it like this
W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = (b, c, U, W)

opt = ADAM(0.01)
x = # generate input to gradients
grads = gradients(x) # return tuple (gb, gc, gU, gW)
update!(opt, θ, grads)

ОБНОВЛЕНИЕ 2:

Другой вариант по-прежнему использовать Zygote для получения градиентов, чтобы он автоматически устанавливал для вас объект Grads, но использовать пользовательское сопряжение, чтобы оно использовало вашу аналитическую функцию для вычисления сопряженного. Предположим, ваша модель ML определена как функция с именем f, поэтому f(x) возвращает выходные данные вашей модели для ввода x. Предположим также, что gradients(x) возвращает аналитические градиенты относительно x, как вы упоминали в своем вопросе. Тогда следующий код все еще будет использовать AD Zygote, который будет правильно заполнять объект Grads, но он будет использовать ваше определение вычисления градиентов для вашей функции f:

W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = Flux.Params([b, c, U, W])

f(x) = # define your model
gradients(x) = # define your analytical gradient

# set up the custom adjoint
Zygote.@adjoint f(x) = f(x), Δ -> (gradients(x),)

opt = ADAM(0.01)
x = # generate input to model
y = # output of model
grads = Zygote.gradient(() -> Flux.mse(f(x), y), θ)
update!(opt, θ, grads)

Обратите внимание, что я использовал Flux.mse как пример потери выше. Недостатком этого подхода является то, что функция Zygote gradient требует скалярного вывода. Если ваша модель передается с некоторой потерей, которая выдаст скалярное значение ошибки, тогда @adjoint - лучший подход. Это подходит для ситуации, когда вы выполняете стандартную ML, и единственное изменение заключается в том, что вы будете использовать sh для Zygote, чтобы аналитически вычислить градиент f с помощью вашей функции.

Если вы что-то делаете более сложный и не может использовать Zygote.gradient, тогда первый подход (без использования Params) является наиболее подходящим. Params действительно существует только для обратной совместимости со старым AD Flux, поэтому лучше по возможности его избегать.

...