From ff780fdc8111a561f6a5e31874bb48cd67b17613 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 26 Nov 2024 00:42:13 +0100 Subject: [PATCH 1/5] Specialize `sample` for sparse weights --- src/sampling.jl | 5 +++++ test/wsampling.jl | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/sampling.jl b/src/sampling.jl index c6294c979..dfc29b087 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -606,6 +606,11 @@ sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv) # Specialization for `UnitWeights` sample(rng::AbstractRNG, wv::UnitWeights) = rand(rng, 1:length(wv)) +function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,T,V}) where {T<:Real,V<:SparseVector{T}} + i = sample(rng, Weights(nonzeros(wv.values), sum(wv))) + return SparseArrays.nonzeroinds(wv.values)[i] +end + """ direct_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray) diff --git a/test/wsampling.jl b/test/wsampling.jl index efe9a608f..9b7ebc155 100644 --- a/test/wsampling.jl +++ b/test/wsampling.jl @@ -1,5 +1,5 @@ using StatsBase -using Random, Test, OffsetArrays +using Random, Test, OffsetArrays, SparseArrays Random.seed!(1234) @@ -41,6 +41,7 @@ for wv in ( weights([0.2, 0.8, 0.4, 0.6]), weights([2, 8, 4, 6]), weights(Float32[0.2, 0.8, 0.4, 0.6]), + weights(sparsevec([0, 8, 0, 6])), Weights(Float32[0.2, 0.8, 0.4, 0.6], 2), Weights([2, 8, 4, 6], 20.0), ) From 3b442b34cf0956d2aa7570547e0a0b8182c295d9 Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Wed, 7 Jan 2026 16:34:18 +0100 Subject: [PATCH 2/5] Update src/sampling.jl Co-authored-by: Milan Bouchet-Valat --- src/sampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sampling.jl b/src/sampling.jl index dfc29b087..854e64b8e 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -606,7 +606,7 @@ sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv) # Specialization for `UnitWeights` sample(rng::AbstractRNG, wv::UnitWeights) = rand(rng, 1:length(wv)) -function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,T,V}) where {T<:Real,V<:SparseVector{T}} +function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,<:Real,<:SparseVector}) i = sample(rng, Weights(nonzeros(wv.values), sum(wv))) return SparseArrays.nonzeroinds(wv.values)[i] end From db126d10e8d1345148bfe1c8ce92800e57fc755e Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Wed, 7 Jan 2026 16:35:56 +0100 Subject: [PATCH 3/5] Update src/sampling.jl Co-authored-by: Milan Bouchet-Valat --- src/sampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sampling.jl b/src/sampling.jl index 854e64b8e..309a4427c 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -608,7 +608,7 @@ sample(rng::AbstractRNG, wv::UnitWeights) = rand(rng, 1:length(wv)) function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,<:Real,<:SparseVector}) i = sample(rng, Weights(nonzeros(wv.values), sum(wv))) - return SparseArrays.nonzeroinds(wv.values)[i] + return rowvals(wv.values)[i] end """ From b5f5af162649f35faf7874fa0719da919f5d01d6 Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Wed, 7 Jan 2026 17:05:26 +0100 Subject: [PATCH 4/5] Add test for sampling without replacement --- test/wsampling.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/wsampling.jl b/test/wsampling.jl index 9b7ebc155..8f0b3053d 100644 --- a/test/wsampling.jl +++ b/test/wsampling.jl @@ -131,6 +131,10 @@ test_rng_use(efraimidis_aexpj_wsample_norep!, 4:7, wv, zeros(Int, 2)) a = sample(4:7, wv, 3; replace=false, ordered=false) check_wsample_norep(a, (4, 7), wv, -1; ordered=false) +wv_sparse = weights(sparsevec([0, 8, 0, 6])) +a = sample(4:7, wv_sparse, 3; replace=false, ordered=false) +check_wsample_norep(a, (4, 7), wv_sparse, -1; ordered=false) + for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int}) r = rev ? reverse(4:7) : (4:7) r = T===Int ? r : T.(r) From 1f8f74da82817ffbb6e69da634b5d0b88c6cd5f1 Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Wed, 7 Jan 2026 17:16:28 +0100 Subject: [PATCH 5/5] fix test --- test/wsampling.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/wsampling.jl b/test/wsampling.jl index 8f0b3053d..89019a3e2 100644 --- a/test/wsampling.jl +++ b/test/wsampling.jl @@ -131,9 +131,9 @@ test_rng_use(efraimidis_aexpj_wsample_norep!, 4:7, wv, zeros(Int, 2)) a = sample(4:7, wv, 3; replace=false, ordered=false) check_wsample_norep(a, (4, 7), wv, -1; ordered=false) -wv_sparse = weights(sparsevec([0, 8, 0, 6])) -a = sample(4:7, wv_sparse, 3; replace=false, ordered=false) -check_wsample_norep(a, (4, 7), wv_sparse, -1; ordered=false) +wv_sparse = weights(sparsevec([0, 8, 4, 6])) +a_sparse = sample(4:7, wv_sparse, 3; replace=false, ordered=false) +check_wsample_norep(a_sparse, (4, 7), wv_sparse, -1; ordered=false) for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int}) r = rev ? reverse(4:7) : (4:7)