-
Notifications
You must be signed in to change notification settings - Fork 440
Dispatch for drawing multiples #1985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
d30e355
4aef890
3fc31ce
5f406ab
c10ef3e
df17b50
7501a5a
832519f
ee17d46
8a9c15a
4fed744
0d303ba
4534327
97e3b25
cbe3df6
ad53e2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -477,6 +477,34 @@ rand(rng::AbstractRNG, s::MixtureSampler{Univariate}) = | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rand(rng::AbstractRNG, d::MixtureModel{Univariate}) = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rand(rng, component(d, rand(rng, d.prior))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| function rand(rng::AbstractRNG, d::MixtureModel{Univariate}, n::Int) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the wrong dispatch, isn't it? If one wants to draw multiple samples from a distribution However, generally for univariate distributions one also shouldn't define Distributions.jl/src/genericrand.jl Line 35 in f1ff9e8
Distributions.jl/src/univariates.jl Lines 141 to 150 in f1ff9e8
So AFAICT we should only define
Suggested change
Alternatively, if we never want to use
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| counts = rand(rng, Multinomial(n, probs(d.prior))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Find the component with the maximum count to minimize resizing | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_count_idx = argmax(counts) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_count = counts[max_count_idx] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Sample from the component with maximum count first and use it directly | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = rand(rng, component(d, max_count_idx), max_count) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Resize to the full size and continue with other components | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| resize!(x, n) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| offset = max_count | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in eachindex(counts) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if i != max_count_idx | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ni = counts[i] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ni > 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| c = component(d, i) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| last_offset = offset + ni - 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rand!(rng, c, @view(x[(begin+offset):(begin+last_offset)])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| offset = last_offset + 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+483
to
+504
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the in-place method, it seems this could be simplified to
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return shuffle!(rng, x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # multivariate mixture sampler for a vector | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _rand!(rng::AbstractRNG, s::MixtureSampler{Multivariate}, x::AbstractVector{<:Real}) = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rand!(rng, s.csamplers[rand(rng, s.psampler)], x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -233,6 +233,62 @@ function rand(rng::AbstractRNG, d::Truncated) | |||||||
| end | ||||||||
| end | ||||||||
|
|
||||||||
| function rand(rng::AbstractRNG, d::Truncated, n::Int) | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
And if there's any precomputations that could be factored out (doesn't seem to be the case?), then we should think about defining a dedicated sampler. |
||||||||
| d0 = d.untruncated | ||||||||
| tp = d.tp | ||||||||
| lower = d.lower | ||||||||
| upper = d.upper | ||||||||
|
|
||||||||
| # Use the same three regimes as the scalar version | ||||||||
| if tp > 0.25 | ||||||||
| # Regime 1: Rejection sampling with batch optimization | ||||||||
| # Get the correct type and memory by sampling from the untruncated distribution | ||||||||
| samples = rand(rng, d0, n) | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| n_collected = 0 | ||||||||
| max_batch = 0 | ||||||||
| batch_buffer = Vector{eltype(samples)}() | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A separate batch buffer seems unnecessary, in particular the resizing might be inefficient? Instead of copying from a separate vector we could just use the output vector and move samples around there and keep track of the last accepted index. |
||||||||
| while n_collected < n | ||||||||
| n_remaining = n - n_collected | ||||||||
| n_expected = n_remaining / tp | ||||||||
| δn_expected = sqrt(n_remaining * tp * (1 - tp)) | ||||||||
| n_batch_f = n_expected + 3δn_expected | ||||||||
| n_batch = ceil(Int, n_batch_f) | ||||||||
| if n_batch > max_batch | ||||||||
| resize!(batch_buffer, n_batch) | ||||||||
| max_batch = n_batch | ||||||||
| end | ||||||||
| rand!(rng, d0, batch_buffer) | ||||||||
| for i in 1:n_batch | ||||||||
| s = batch_buffer[i] | ||||||||
| if _in_closed_interval(s, lower, upper) | ||||||||
| n_collected += 1 | ||||||||
| samples[n_collected] = s | ||||||||
| n_collected == n && break | ||||||||
| end | ||||||||
| end | ||||||||
| end | ||||||||
| return samples | ||||||||
| elseif tp > sqrt(eps(typeof(float(tp)))) | ||||||||
| # Regime 2: Quantile-based sampling | ||||||||
| # Sample one value first to determine the correct type | ||||||||
| sample_type = typeof(quantile(d0, d.lcdf + rand(rng) * d.tp)) | ||||||||
| samples = Vector{sample_type}(undef, n) | ||||||||
|
Comment on lines
+273
to
+275
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| for i in 1:n | ||||||||
| samples[i] = quantile(d0, d.lcdf + rand(rng) * d.tp) | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably at least use a |
||||||||
| end | ||||||||
| return samples | ||||||||
| else | ||||||||
| # Regime 3: Log-space computation | ||||||||
| # Sample one value first to determine the correct type | ||||||||
| sample_type = typeof(invlogcdf(d0, logaddexp(d.loglcdf, d.logtp - randexp(rng)))) | ||||||||
| samples = Vector{sample_type}(undef, n) | ||||||||
|
Comment on lines
+282
to
+284
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| for i in 1:n | ||||||||
| samples[i] = invlogcdf(d0, logaddexp(d.loglcdf, d.logtp - randexp(rng))) | ||||||||
| end | ||||||||
| return samples | ||||||||
| end | ||||||||
| end | ||||||||
|
|
||||||||
| ## show | ||||||||
|
|
||||||||
| function show(io::IO, d::Truncated) | ||||||||
|
|
||||||||
Uh oh!
There was an error while loading. Please reload this page.