diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 2997947..691d52a 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -162,29 +162,32 @@ function recompute_skip!(s::MultiAlgRSWRSKIPSampler, n) return s end +macro quantile_fast(k) + block = Expr(:block) + firstv = quote + $(esc(:s)) = $(esc(:n)) * $(esc(:p)) + $(esc(:q)) = 1. - $(esc(:p)) + $(esc(:x)) = 1. + $(esc(:s)) / $(esc(:q)) + $(esc(:x)) > $(esc(:nt)) && return 1 + end + append!(block.args, firstv.args) + for i in 2:k + nextv = quote + $(esc(:s)) *= ($(esc(:n)) - $i) * $(esc(:p)) + $(esc(:q)) *= 1. - $(esc(:p)) + $(esc(:x)) += $(esc(:s)) / ($(esc(:q)) * $(factorial(i))) + $(esc(:x)) > $(esc(:nt)) && return $i + end + append!(block.args, nextv.args) + end + return block +end + @inline function choose(rng, n, p) z = exp(n*log1p(-p)) t = rand(rng, Uniform(z, 1.0)) - s = n*p - q = 1-p - x = z + z*s/q - x > t && return 1 - s *= (n-1)*p - q *= 1-p - x += (s*z/q)/2 - x > t && return 2 - s *= (n-2)*p - q *= 1-p - x += (s*z/q)/6 - x > t && return 3 - s *= (n-3)*p - q *= 1-p - x += (s*z/q)/24 - x > t && return 4 - s *= (n-4)*p - q *= 1-p - x += (s*z/q)/120 - x > t && return 5 + nt = t/z + @quantile_fast(8) return quantile(Binomial(n, p), t) end @@ -274,3 +277,5 @@ function ordvalue(s::MultiOrdAlgRSWRSKIPSampler) return s.value[sortperm(s.ord)] end end + +