5 changes: 5 additions & 0 deletions src/sampling.jl
@@ -606,6 +606,11 @@ sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv)
# Specialization for `UnitWeights`
sample(rng::AbstractRNG, wv::UnitWeights) = rand(rng, 1:length(wv))

function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,<:Real,<:SparseVector})
    i = sample(rng, Weights(nonzeros(wv.values), sum(wv)))
    return rowvals(wv.values)[i]
Comment on lines +610 to +611
Member


This code is unsafe: in general, `AbstractWeights` subtypes are not required to have a `values` field. Only a few `AbstractWeights` subtypes in StatsBase have one, and it is undocumented and internal.

Member

@nalimilan nalimilan Jan 8, 2026


So it is actually better to define this method only for the types defined in StatsBase, probably using:

for W in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights)
    @eval function sample(rng::AbstractRNG, wv::$W{<:Real,<:Real,<:SparseVector})
    ...

(I'm saying this because AFAICT there's no public API which allows accessing the backing array. And anyway I'm not aware of custom AbstractWeights types defined elsewhere so we don't really care to apply this optimization to them.)
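The per-type loop suggested above can be sketched without StatsBase. This is only an illustration under stated assumptions: `ToyWeights`, `ToyFrequencyWeights`, `toy_sample`, and `sample_nonzero` are hypothetical names, not part of StatsBase or this PR. The toy structs stand in for the concrete weight types and expose a `values` field by construction.

```julia
using Random, SparseArrays

# Hypothetical stand-ins for concrete weight types (not the StatsBase definitions).
struct ToyWeights{V<:AbstractVector}
    values::V
end
struct ToyFrequencyWeights{V<:AbstractVector}
    values::V
end

# Draw an index with probability proportional to the stored weights,
# visiting only the stored (nonzero) entries of the sparse vector.
# Assumes `v` has at least one stored entry.
function sample_nonzero(rng::AbstractRNG, v::SparseVector)
    inds, vals = findnz(v)
    t = rand(rng) * sum(vals)
    c = zero(float(eltype(vals)))
    for k in eachindex(vals)
        c += vals[k]
        t < c && return inds[k]
    end
    return inds[end]  # guard against floating-point round-off
end

# Define the method once per concrete type. `$W` interpolates the loop
# variable into the expression, because `@eval` evaluates at global scope
# where the loop-local `W` is not visible.
for W in (ToyWeights, ToyFrequencyWeights)
    @eval toy_sample(rng::AbstractRNG, wv::$W{<:SparseVector}) =
        sample_nonzero(rng, wv.values)
end
```

Because each method is attached to a concrete type that is known to have the field, a custom weights type without `values` never reaches this code path.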

end

"""
direct_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray)

7 changes: 6 additions & 1 deletion test/wsampling.jl
Member


The tests are insufficient: since the method is implemented for `AbstractWeights`, we should test every subtype implemented in StatsBase, plus a custom subtype of `AbstractWeights`, to be sure it works for more than just `Weights`.
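A minimal sketch of what broader coverage could look like, assuming the `weights`, `aweights`, `fweights`, and `pweights` constructors accept a `SparseVector` the same way `weights` does in this PR's tests. It covers only the built-in types. The custom-subtype case is left out because it would require implementing the rest of the `AbstractWeights` interface for that type.

```julia
using StatsBase, SparseArrays, Random, Test

rng = MersenneTwister(1234)
v = sparsevec([0, 8, 0, 6])  # only indices 2 and 4 carry weight

@testset "sample with sparse weights" begin
    for wv in (weights(v), aweights(v), fweights(v), pweights(v))
        draws = [sample(rng, wv) for _ in 1:1_000]
        # Indices with zero weight should not be drawn.
        @test all(i -> i in (2, 4), draws)
    end
end
```

The check is deliberately loose: it holds for any correct weighted sampler, so it does not depend on which method dispatch selects.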

@@ -1,5 +1,5 @@
using StatsBase
-using Random, Test, OffsetArrays
+using Random, Test, OffsetArrays, SparseArrays

Random.seed!(1234)

@@ -41,6 +41,7 @@ for wv in (
    weights([0.2, 0.8, 0.4, 0.6]),
    weights([2, 8, 4, 6]),
    weights(Float32[0.2, 0.8, 0.4, 0.6]),
    weights(sparsevec([0, 8, 0, 6])),
    Weights(Float32[0.2, 0.8, 0.4, 0.6], 2),
    Weights([2, 8, 4, 6], 20.0),
)
Comment thread
AntonOresten marked this conversation as resolved.
@@ -130,6 +131,10 @@ test_rng_use(efraimidis_aexpj_wsample_norep!, 4:7, wv, zeros(Int, 2))
a = sample(4:7, wv, 3; replace=false, ordered=false)
check_wsample_norep(a, (4, 7), wv, -1; ordered=false)

wv_sparse = weights(sparsevec([0, 8, 4, 6]))
a_sparse = sample(4:7, wv_sparse, 3; replace=false, ordered=false)
check_wsample_norep(a_sparse, (4, 7), wv_sparse, -1; ordered=false)

for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int})
    r = rev ? reverse(4:7) : (4:7)
    r = T===Int ? r : T.(r)