Тип нестабильности в случайных данных с заданным распределением - PullRequest
4 голосов
/ 10 марта 2020

Я хочу сгенерировать данные из линейной модели с шумом (Y = X * w + e), где я могу указать распределения входного вектора X и скалярного шума e. Для этого я указываю нижеприведенную структуру

using Distributions

struct NoisyLinearDataGenerator
    x_dist::ContinuousMultivariateDistribution
    noise_dist::ContinuousUnivariateDistribution
    weights::Vector{Float64}
end

и функцию для генерации из нее N точек:

function generate(nl::NoisyLinearDataGenerator, N)
    x = rand(nl.x_dist, N)'
    e = rand(nl.noise_dist, N)
    return x, x*nl.weights + e
end

Кажется, что это работает, но не стабильно, например

nl = NoisyLinearDataGenerator(MvNormal(5, 1.0), Normal(), ones(5))

@code_warntype generate(nl,1)

приводит к

Variables
  #self#::Core.Compiler.Const(generate, false)
  nl::NoisyLinearDataGenerator
  N::Int64
  x::Any
  e::Any

Body::Tuple{Any,Any}
1 ─ %1  = Base.getproperty(nl, :x_dist)::Distribution{Multivariate,Continuous}
│   %2  = Main.rand(%1, N)::Any
│         (x = Base.adjoint(%2))
│   %4  = Base.getproperty(nl, :noise_dist)::Distribution{Univariate,Continuous}
│         (e = Main.rand(%4, N))
│   %6  = x::Any
│   %7  = x::Any
│   %8  = Base.getproperty(nl, :weights)::Array{Float64,1}
│   %9  = (%7 * %8)::Any
│   %10 = (%9 + e)::Any
│   %11 = Core.tuple(%6, %10)::Tuple{Any,Any}
└──       return %11

Я не уверен, почему это так, поскольку я ожидаю, что тип выборочных данных будет определен с использованием ContinuousMultivariateDistribution и ContinuousUnivariateDistribution.

Что здесь приводит к нестабильности типов и как должна выглядеть стабильная реализация типов?

1 Ответ

6 голосов
/ 10 марта 2020

Проблема в том, что ContinuousMultivariateDistribution и ContinuousUnivariateDistribution являются абстрактными типами. Хотя ваши знания статистики говорят о том, что они, вероятно, должны возвращать Float64, на уровне языка нет гарантии, что кто-то не реализует, скажем, ContinuousUnivariateDistribution, который возвращает какой-то другой объект. Поэтому компилятор не может знать, что all ContinuousUnivariateDistribution производит какой-либо конкретный тип.

Например, я мог бы написать:

struct BadDistribution <: ContinuousUnivariateDistribution end
Base.rand(::BadDistribution, ::Integer) = nothing

Теперь вы можете сделать NoisyLinearDataGenerator, содержащий BadDistribution в качестве x_dist. Каким будет тип вывода в таком случае?

Другими словами, вывод generate просто не может прогнозироваться только по его типам ввода.

Решить для этого вам нужно либо указать специфицированные c дистрибутивы для вашего нового типа, либо сделать параметр нового типа параметри c. В Julia, когда у нас есть поле типа, которое не может быть указано для конкретного типа, мы обычно оставляем его как параметр типа. Таким образом, одним из возможных решений является следующее:

using Distributions

struct NoisyLinearDataGenerator{X,N}
    x_dist::X
    noise_dist::N
    weights::Vector{Float64}

    function NoisyLinearDataGenerator{X,N}(x::X, n::N, w::Vector{Float64}) where {
                                    X <: ContinuousMultivariateDistribution,
                                    N <: ContinuousUnivariateDistribution}
        return new{X,N}(x,n,w)
    end
end

function NoisyLinearDataGenerator(x::X, n::N, w::Vector{Float64}) where {
                                X <: ContinuousMultivariateDistribution,
                                N <: ContinuousUnivariateDistribution}
    return NoisyLinearDataGenerator{X,N}(x,n,w)
end

function generate(nl::NoisyLinearDataGenerator, N)
    x = rand(nl.x_dist, N)'
    e = rand(nl.noise_dist, N)
    return x, x*nl.weights + e
end
nl = NoisyLinearDataGenerator(MvNormal(5, 1.0), Normal(), ones(5))

Здесь тип nl равен NoisyLinearDataGenerator{MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}},Normal{Float64}} (да, я знаю, ужасно читать), но его тип содержит всю информацию, необходимую для компилятора чтобы полностью предсказать тип вывода generate.

...