Регистрация потерь во время обучения в Flux с использованием обратных вызовов - PullRequest
1 голос
/ 03 мая 2020

Я пытаюсь написать обратный вызов для функции train! в Flux. Мой код:

cb_loss = x -> push!(x, loss(x_train, y_train))
loss_vector = Vector{Float32}()

Flux.train!(loss, ps, train_data, opt, cb=cb_loss(loss_vector))

Это дает мне эту ошибку:

MethodError: objects of type Float32 are not callable

Stacktrace:
 [1] call(::Float32) at C:\Users\arjur\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:36
 [2] foreach at .\abstractarray.jl:1920 [inlined]
 [3] #10 at C:\Users\arjur\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:38 [inlined]
 [4] macro expansion at C:\Users\arjur\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:93 [inlined]
 [5] macro expansion at C:\Users\arjur\.julia\packages\Juno\oLB1d\src\progress.jl:134 [inlined]
 [6] #train!#12(::Array{Float32,1}, ::typeof(Flux.Optimise.train!), ::typeof(loss), ::Zygote.Params, ::DataLoader, ::Descent) at C:\Users\arjur\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:81
 [7] (::Flux.Optimise.var"#kw##train!")(::NamedTuple{(:cb,),Tuple{Array{Float32,1}}}, ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::DataLoader, ::Descent) at .\none:0
 [8] top-level scope at In[108]:1

Интересно, что он правильно добавляет первое значение к вектору, а затем вылетает, поэтому я предполагаю, что сообщение об ошибке связано с что.

Я проверил функцию вне функции train!, и она работает, так как я должен переписать эту функцию, чтобы записать потерю в векторе?

1 Ответ

0 голосов
/ 03 мая 2020

Похоже, вам нужно передать это так: cb=callback. Так что это может быть сделано либо с использованием глобальных переменных, либо путем определения обратного вызова следующим образом:

callback() = push!(loss_vector, loss(x_train, y_train))
...