Использование квантиля в Flux (Юлия) в функции потерь - PullRequest
6 голосов
/ 16 января 2020

Я пытаюсь использовать квантиль в функции потерь для тренировки! (для некоторой устойчивости, например, наименьших усеченных квадратов), но он мутирует массив, и Zygote выдает ошибку Mutating arrays is not supported, исходящую из sort!. Ниже приведен простой пример (содержание, конечно, не имеет смысла):

using Flux, StatsBase
xdata = randn(2, 100)   
ydata = randn(100)

model = Chain(Dense(2,10), Dense(10, 1))


function trimmedLoss(x,y; trimFrac=0.f05)
        yhat = model(x)
        absRes = abs.(yhat .- y) |> vec
        trimVal = quantile(absRes, 1.f0-trimFrac) 
        s = sum(ifelse.(absRes .> trimVal,  0.f0 , absRes ))/(length(absRes)*(1.f0-trimFrac))
        #s = sum(absRes)/length(absRes)   # using this and commenting out the two above works (no surprise)    
end

println(trimmedLoss(xdata, ydata)) #works ok

Flux.train!(trimmedLoss, params(model), zip([xdata], [ydata]), ADAM())

println(trimmedLoss(xdata, ydata)) #changed loss?

Это все в Flux 0.10 с Julia 1.2

Заранее спасибо за любые подсказки или обходные пути!

1 Ответ

5 голосов
/ 20 января 2020

В идеале мы бы определили пользовательское присоединение для quantile, чтобы это работало «из коробки». (Не стесняйтесь, чтобы открыть выпуск , чтобы напомнить нам сделать это.)

В то же время есть быстрый обходной путь. Это на самом деле сортировка, которая вызывает проблемы здесь, так что если вы сделаете quantile(xs, p, sorted=true), это будет работать. Очевидно, что для получения правильных результатов требуется сортировка xs, поэтому вам может потребоваться использовать quantile(sort(xs), ...).

В зависимости от версии Zygote вам также может потребоваться сопряжение для sort. Это довольно просто:

julia> using Zygote: @adjoint

julia> @adjoint function sort(x)
         p = sortperm(x)
         x[p], x̄ -> (x̄[invperm(p)],)
       end

julia> gradient(x -> quantile(sort(x), 0.5, sorted=true), [1, 2, 3, 3])
([0.0, 0.5, 0.5, 0.0],)

Мы сделаем это встроенным в следующем выпуске Zygote, но сейчас, если вы добавите это в свой скрипт, он заработает ваш код.

...