Генерировать фиксированное количество результатов, удовлетворяющих некоторым условиям - PullRequest
1 голос
/ 08 ноября 2019

Есть ли способ вызывать функцию несколько раз, пока не будет собрано определенное количество результатов с определенными атрибутами?

Пример:

Взяв функцию rand(), я хочу сохранитьтолько результаты, которые выше >= 0.5 и генерируют 10 выборок.

Конечно, я мог бы сделать что-то вроде rand(Truncated(Uniform(0,1), 0, 0.5), 10), используя пакет Distributions, но я ищу более абстрактное решение.

Не оченьрешение для защиты:

До сих пор я нашел только следующее:

using IterTools
cond(x) = ...
f() = ...
gen = IterTools.repeatedly(f)
samples = collect(IterTools.take((n for n in gen if cond(n)), size))

Для приведенного выше примера:

using Distributions
using IterTools
cond(x) = x >= 0.5
f() = rand(Uniform(0,1))
gen = IterTools.repeatedly(f)
rnd_nodes = collect(IterTools.take((n for n in gen if cond(n)), 10))

Но есть ли, возможно, более короткий / лаконичный /более читабельный способ сделать это?

Ответы [ 2 ]

2 голосов
/ 08 ноября 2019

Есть ли способ повторно вызывать функцию, пока не будет собрано определенное количество результатов с определенными атрибутами?

Конечно, используйте простой цикл while.

using BenchmarkTools, Transducers

function loop()
    res = Vector{Float64}(undef,10)
    i = 0
    while i<10
        r = rand()
        if r >= 0.5
            i+=1
            res[i] = r
        end
    end
    return res
end

function transducer() # @Jun Tian's answer
    t = Map(_ -> rand()) |> Filter(x -> x >=0.5) |> Take(10)
    return collect(t, Iterators.repeated(1))
end

@btime transducer(); # 687ns
@btime loop(); # 170ns
@btime 0.5 .* rand(10) .+ 0.5; # 86ns

Для сравнения я добавил 0.5 .* rand(10) .+ 0.5, который просто дает вам то, что вы хотите, не повторяя и не полагаясь на «удачу».

0 голосов
/ 08 ноября 2019

Попробуйте Transducers.jl

julia> using Transducers

julia> t = Map(_ -> rand()) |> Filter(x -> x >=0.5) |> Take(10)
Map(Main.λ❓) |>
    Filter(Main.λ❓) |>
    Take(10)

julia> collect(t, Iterators.repeated(1))
10-element Array{Float64,1}:
 0.6125615651973046
 0.9271858603504375
 0.8218768467739419
 0.5719380767545377
 0.7073831906599655
 0.5228490007486046
 0.9929437973392725
 0.6935716395158282
 0.6663379802812248
 0.6149007269488846
...