Ограничение подписей функций при использовании ForwardDiff в Julia - PullRequest
0 голосов
/ 06 мая 2019

Я пытаюсь использовать ForwardDiff в библиотеке, где почти все функции ограничены для использования только в Float. Я хочу обобщить эти сигнатуры функций, чтобы можно было использовать ForwardDiff, оставаясь при этом достаточно ограничительным, чтобы функции принимали только числовые значения, а не такие вещи, как Dates. У меня есть множество функций с одним и тем же именем, но разных типов (то есть функций, которые принимают «время» в виде числа с плавающей запятой или даты с тем же именем функции) и не хотят во всех случаях удалять классификаторы типов.

Минимальный рабочий пример

using ForwardDiff
x = [1.0, 2.0, 3.0, 4.0 ,5.0]
typeof(x) # Array{Float64,1}
function G(x::Array{Real,1})
    return sum(exp.(x))
end
function grad_F(x::Array)
  return ForwardDiff.gradient(G, x)
end
G(x) # Method Error
grad_F(x) # Method error

function G(x::Array{Float64,1})
    return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This has a method error

function G(x)
    return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This works
# But now I cannot restrict the function G to only take numeric arrays and not for instance arrays of Dates.

Есть ли способ ограничить функции, чтобы они принимали только числовые значения (Ints и Float) и любые структуры двойных чисел, которые использует ForwardDiff, но не разрешали символы, даты и т. Д.

Ответы [ 2 ]

2 голосов
/ 06 мая 2019

ForwardDiff.Dual является подтипом абстрактного типа Real. Проблема, однако, заключается в том, что параметры типа Джулии являются инвариантными, а не ковариантными. Затем следующее возвращает false.

# check if `Array{Float64, 1}` is a subtype of `Array{Real, 1}`
julia> Array{Float64, 1} <: Array{Real, 1}
false

Это делает ваше определение функции

function G(x::Array{Real,1})
    return sum(exp.(x))
end

неверно (не подходит для вашего использования). Вот почему вы получаете следующую ошибку.

julia> G(x)
ERROR: MethodError: no method matching G(::Array{Float64,1})

Правильное определение должно быть скорее

function G(x::Array{<:Real,1})
    return sum(exp.(x))
end

или если вам как-то нужен простой доступ к конкретному типу элемента массива

 function G(x::Array{T,1}) where {T<:Real}
     return sum(exp.(x))
 end

То же самое относится и к вашей grad_F функции.

Может оказаться полезным прочитать соответствующий раздел документации Julia для типов.


Возможно, вы также захотите напечатать аннотации ваших функций для типа AbstractArray{<:Real,1} вместо Array{<:Real, 1}, чтобы ваши функции могли работать с другими типами массивов, такими как StaticArrays, OffsetArrays и т. Д., Без необходимости переопределения.

1 голос
/ 06 мая 2019

Это будет принимать любой тип массива, параметризованный любым числом:

function foo(xs::AbstractArray{<:Number})
  @show typeof(xs)
end

или:

function foo(xs::AbstractArray{T}) where T<:Number
  @show typeof(xs)
end

В случае, если вам нужно обратиться к параметру типа Tвнутри функции тела.

x1 = [1.0, 2.0, 3.0, 4.0 ,5.0]
x2 = [1, 2, 3,4, 5]
x3 = 1:5
x4 = 1.0:5.0
x5 = [1//2, 1//4, 1//8]

xss = [x1, x2, x3, x4, x5]

function foo(xs::AbstractArray{T}) where T<:Number
  @show xs typeof(xs) T
  println()
end

for xs in xss
  foo(xs)
end

Выходы:

xs = [1.0, 2.0, 3.0, 4.0, 5.0]
typeof(xs) = Array{Float64,1}
T = Float64

xs = [1, 2, 3, 4, 5]
typeof(xs) = Array{Int64,1}
T = Int64

xs = 1:5
typeof(xs) = UnitRange{Int64}
T = Int64

xs = 1.0:1.0:5.0
typeof(xs) = StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}
T = Float64

xs = Rational{Int64}[1//2, 1//4, 1//8]
typeof(xs) = Array{Rational{Int64},1}
T = Rational{Int64}

Вы можете запустить пример кода здесь: https://repl.it/@SalchiPapa/Restricting-function-signatures-in-Julia

...