diff --git a/src/sampling.jl b/src/sampling.jl index c6294c979..dc1634fae 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -596,6 +596,12 @@ function sample(rng::AbstractRNG, wv::AbstractWeights) i += 1 cw += wv[i] end + if cw < t + # may happen with floating point weights due to numerical inaccuracies + while iszero(wv[i]) + i -= 1 + end + end return i end sample(wv::AbstractWeights) = sample(default_rng(), wv) @@ -1063,3 +1069,4 @@ wsample(rng::AbstractRNG, a::AbstractArray{T}, w::AbstractVector{<:Real}, dims:: wsample(a::AbstractArray, w::AbstractVector{<:Real}, dims::Dims; replace::Bool=true, ordered::Bool=false) = wsample(default_rng(), a, w, dims; replace=replace, ordered=ordered) + diff --git a/test/sampling.jl b/test/sampling.jl index a4f31a012..c30f5c2d2 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -298,6 +298,18 @@ end end end +@testset "issue #982" begin + # Issue #982 was triggered in conjunction with SIMD support under certain circumstances, + # in particular with the following weights when sum(w) == 0.662413f0. + # The following mimics this case. + w = Float32[0.0437019, 0.04302464, 0.039748967, 0.040406376, 0.042578973, + 0.040906563, 0.039586294, 0.04302464, 0.042357873, 0.04302464, 0.039262936, + 0.040406376, 0.040406376, 0.041919112, 0.041484896, 0.04057242, 0.0] + rng = StableRNG(889858990530) + s = sample(rng, Weights(w, 0.662413f0)) + @test s == length(w) - 1 + @test sample(rng, Weights([1, 2, 0, 0], 10000)) == 2 # another more trivial test +end # Custom weights without `values` field struct YAUnitWeights <: AbstractWeights{Int, Int, Vector{Int}} n::Int