Zygote.Hessian: мутирующие массивы не поддерживаются - PullRequest
0 голосов
/ 05 августа 2020

Я пытаюсь вычислить лапласиан нейронной сети. Это мой код:

using Flux
using Zygote

model = Chain(Dense(2,5,sigmoid), Dense(5,1))

function laplace(x)
    a, b = size(x)
    Δ = Zygote.Buffer(zeros(b))
    deriv2 = sum(Diagonal(ones(a*b)).*Zygote.hessian(v -> sum(model(v)), x), dims=1)
    for i=1:b
        for j=1:a
            Δ[i] += deriv2[(i-1)*a+j]
        end
    end
    return copy(Δ)
end

gradient(x -> sum(laplace(x)), rand(2,5))

Я увижу ту же ошибку, даже если я определю такую ​​функцию:

function function(x)
    return sum(Zygote.hessian(v -> sum(model(v)), x))
end

gradient(x -> function(x), rand(2,5))

Почему я получаю эту ошибку?

...