diff --git a/src/sampling.jl b/src/sampling.jl index c6294c979..309a4427c 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -606,6 +606,11 @@ sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv) # Specialization for `UnitWeights` sample(rng::AbstractRNG, wv::UnitWeights) = rand(rng, 1:length(wv)) +function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,<:Real,<:SparseVector}) + i = sample(rng, Weights(nonzeros(wv.values), sum(wv))) + return rowvals(wv.values)[i] +end + """ direct_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray) diff --git a/test/wsampling.jl b/test/wsampling.jl index efe9a608f..89019a3e2 100644 --- a/test/wsampling.jl +++ b/test/wsampling.jl @@ -1,5 +1,5 @@ using StatsBase -using Random, Test, OffsetArrays +using Random, Test, OffsetArrays, SparseArrays Random.seed!(1234) @@ -41,6 +41,7 @@ for wv in ( weights([0.2, 0.8, 0.4, 0.6]), weights([2, 8, 4, 6]), weights(Float32[0.2, 0.8, 0.4, 0.6]), + weights(sparsevec([0, 8, 0, 6])), Weights(Float32[0.2, 0.8, 0.4, 0.6], 2), Weights([2, 8, 4, 6], 20.0), ) @@ -130,6 +131,10 @@ test_rng_use(efraimidis_aexpj_wsample_norep!, 4:7, wv, zeros(Int, 2)) a = sample(4:7, wv, 3; replace=false, ordered=false) check_wsample_norep(a, (4, 7), wv, -1; ordered=false) +wv_sparse = weights(sparsevec([0, 8, 4, 6])) +a_sparse = sample(4:7, wv_sparse, 3; replace=false, ordered=false) +check_wsample_norep(a_sparse, (4, 7), wv_sparse, -1; ordered=false) + for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int}) r = rev ? reverse(4:7) : (4:7) r = T===Int ? r : T.(r)