Skip to content
Open
7 changes: 7 additions & 0 deletions src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,12 @@ function sample(rng::AbstractRNG, wv::AbstractWeights)
i += 1
cw += wv[i]
end
if cw < t
# may happen with floating point weights due to numerical inaccuracies
while iszero(wv[i])
i -= 1
end
end
return i
end
sample(wv::AbstractWeights) = sample(default_rng(), wv)
Expand Down Expand Up @@ -1063,3 +1069,4 @@ wsample(rng::AbstractRNG, a::AbstractArray{T}, w::AbstractVector{<:Real}, dims::
wsample(a::AbstractArray, w::AbstractVector{<:Real}, dims::Dims;
replace::Bool=true, ordered::Bool=false) =
wsample(default_rng(), a, w, dims; replace=replace, ordered=ordered)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

12 changes: 12 additions & 0 deletions test/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,18 @@ end
end
end

@testset "issue #982" begin
# Issue #982 was triggered in conjunction with SIMD support under certain circumstances,
# in particular with the following weights when sum(w) == 0.662413f0.
# The following mimics this case.
w = Float32[0.0437019, 0.04302464, 0.039748967, 0.040406376, 0.042578973,
0.040906563, 0.039586294, 0.04302464, 0.042357873, 0.04302464, 0.039262936,
0.040406376, 0.040406376, 0.041919112, 0.041484896, 0.04057242, 0.0]
Comment on lines +305 to +307
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
w = Float32[0.0437019, 0.04302464, 0.039748967, 0.040406376, 0.042578973,
0.040906563, 0.039586294, 0.04302464, 0.042357873, 0.04302464, 0.039262936,
0.040406376, 0.040406376, 0.041919112, 0.041484896, 0.04057242, 0.0]
w = Float32[0.0437019, 0.04302464, 0.039748967, 0.040406376, 0.042578973,
0.040906563, 0.039586294, 0.04302464, 0.042357873, 0.04302464, 0.039262936,
0.040406376, 0.040406376, 0.041919112, 0.041484896, 0.04057242, 0.0]

rng = StableRNG(889858990530)
s = sample(rng, Weights(w, 0.662413f0))
@test s == length(w) - 1
@test sample(rng, Weights([1, 2, 0, 0], 10000)) == 2 # another more trivial test
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@test sample(rng, Weights([1, 2, 0, 0], 10000)) == 2 # another more trivial test
# Artificial test with provided sum greater than actual sum
@test sample(rng, Weights([1, 2, 0, 0], 10000)) == 2

end
# Custom weights without `values` field
struct YAUnitWeights <: AbstractWeights{Int, Int, Vector{Int}}
n::Int
Expand Down