# Patch summary (review notes). Everything below this comment block is the
# original patch text, unchanged.
#
# Files touched: src/SamplingReduction.jl, src/UnweightedSamplingMulti.jl,
# test/merge_tests.jl.
#
# src/SamplingReduction.jl
#   Adds `get_ps` methods for `MultiAlgRSampler...` and `MultiAlgLSampler...`.
#   Each returns one weight per sampler, `s.seen_k/sum_w`, where `sum_w` is the
#   sum of the `seen_k` fields. When `sum_w == 0` the method returns the uniform
#   weights `1.0/length(ss)` before any division happens.
#
# src/UnweightedSamplingMulti.jl
#   `Base.merge` for `MultiAlgRSampler...` and `MultiAlgLSampler...`: the
#   "To Be Implemented" stubs are replaced. The new bodies call
#   `reduce_samples(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...)`,
#   sum `seen_k` over the inputs (for AlgL also `state` and `skip_k`), take `n`
#   as the minimum `n` of the inputs, and build a `*_Mut` sampler that reuses
#   `ss[1].rng`.
#   `Base.merge!`: the vararg stubs are removed and re-added with the same
#   visible text, so they still raise "To Be Implemented". New methods whose
#   first argument is `s1::...{<:Nothing}` raise an error when `s1.n` is larger
#   than the minimum `n` of the other samplers, copy `newvalue` element by
#   element into `s1.value`, add the summed counters of `ss` to those of `s1`,
#   and return `s1`.
#
# test/merge_tests.jl
#   Adds a `using` line and, for `AlgR()` and `AlgL()`, checks that `merge` and
#   `merge!` applied to freshly constructed samplers give a `value` of length 0,
#   and that `merge!` returns its first argument (`===`).
#
# NOTE(review): the `@inbounds` writes `s1.value[i] = newvalue[i]` for
#   `i in 1:length(newvalue)` assume `s1.value` is at least as long as
#   `newvalue`; nothing visible in this patch establishes that -- confirm
#   against `reduce_samples`.
# NOTE(review): whether adding `state` and `skip_k` across samplers yields a
#   valid AlgL sampler state is not shown here -- confirm against the AlgL
#   update code.
# NOTE(review): the added tests only cover samplers that have seen no elements;
#   merging non-empty AlgR/AlgL samplers is not exercised in the visible hunk.
# NOTE(review): the hunk headers declare multi-line hunks, but this copy of the
#   patch has its line breaks collapsed; presumably they must be restored
#   before the patch can be applied -- TODO confirm.
diff --git a/src/SamplingReduction.jl b/src/SamplingReduction.jl index b27e987..8875584 100644 --- a/src/SamplingReduction.jl +++ b/src/SamplingReduction.jl @@ -31,6 +31,16 @@ function get_ps(ss::MultiAlgWRSWRSKIPSampler...) sum_w = sum(getfield(s, :state) for s in ss) return [s.state/sum_w for s in ss] end +function get_ps(ss::MultiAlgRSampler...) + sum_w = sum(getfield(s, :seen_k) for s in ss) + sum_w == 0 && return [1.0/length(ss) for _ in ss] # Handle empty case + return [s.seen_k/sum_w for s in ss] +end +function get_ps(ss::MultiAlgLSampler...) + sum_w = sum(getfield(s, :seen_k) for s in ss) + sum_w == 0 && return [1.0/length(ss) for _ in ss] # Handle empty case + return [s.seen_k/sum_w for s in ss] +end get_type_rs(::TypeS, s1::T, ss::T...) where {T} = eltype(s1) function get_type_rs(::TypeUnion, s1::T, ss::T...) where {T} diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 3a0d2f5..db8a4ef 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -209,11 +209,19 @@ end is_ordered(s::MultiOrdAlgRSWRSKIPSampler) = true is_ordered(s::MultiAlgRSWRSKIPSampler) = false -function Base.merge(ss::MultiAlgRSampler...) - error("To Be Implemented") -end -function Base.merge(ss::MultiAlgLSampler...) - error("To Be Implemented") +function Base.merge(ss::MultiAlgRSampler...) + newvalue = reduce_samples(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) + seen_k = sum(getfield(s, :seen_k) for s in ss) + n = minimum(s.n for s in ss) + return MultiAlgRSampler_Mut(n, seen_k, ss[1].rng, newvalue, nothing) +end +function Base.merge(ss::MultiAlgLSampler...) + newvalue = reduce_samples(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) 
+ state = sum(getfield(s, :state) for s in ss) + skip_k = sum(getfield(s, :skip_k) for s in ss) + seen_k = sum(getfield(s, :seen_k) for s in ss) + n = minimum(s.n for s in ss) + return MultiAlgLSampler_Mut(n, state, skip_k, seen_k, ss[1].rng, newvalue, nothing) end function Base.merge(ss::MultiAlgRSWRSKIPSampler...) newvalue = reduce_samples(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) @@ -223,11 +231,31 @@ function Base.merge(ss::MultiAlgRSWRSKIPSampler...) return MultiAlgRSWRSKIPSampler_Mut(n, skip_k, seen_k, ss[1].rng, newvalue, nothing) end -function Base.merge!(ss::MultiAlgRSampler...) - error("To Be Implemented") -end -function Base.merge!(ss::MultiAlgLSampler...) - error("To Be Implemented") +function Base.merge!(ss::MultiAlgRSampler...) + error("To Be Implemented") +end +function Base.merge!(ss::MultiAlgLSampler...) + error("To Be Implemented") +end +function Base.merge!(s1::MultiAlgRSampler{<:Nothing}, ss::MultiAlgRSampler...) + s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") + newvalue = reduce_samples(get_ps(s1, ss...), [s1.rng, [s.rng for s in ss]...], TypeS(), value(s1), value.(ss)...) + for i in 1:length(newvalue) + @inbounds s1.value[i] = newvalue[i] + end + s1.seen_k += sum(getfield(s, :seen_k) for s in ss) + return s1 +end +function Base.merge!(s1::MultiAlgLSampler{<:Nothing}, ss::MultiAlgLSampler...) + s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") + newvalue = reduce_samples(get_ps(s1, ss...), [s1.rng, [s.rng for s in ss]...], TypeS(), value(s1), value.(ss)...) 
+ for i in 1:length(newvalue) + @inbounds s1.value[i] = newvalue[i] + end + s1.state += sum(getfield(s, :state) for s in ss) + s1.skip_k += sum(getfield(s, :skip_k) for s in ss) + s1.seen_k += sum(getfield(s, :seen_k) for s in ss) + return s1 end function Base.merge!(s1::MultiAlgRSWRSKIPSampler{<:Nothing}, ss::MultiAlgRSWRSKIPSampler...) s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") diff --git a/test/merge_tests.jl b/test/merge_tests.jl index d2fc637..ae6ab91 100644 --- a/test/merge_tests.jl +++ b/test/merge_tests.jl @@ -1,4 +1,6 @@ +using StreamSampling, StableRNGs, HypothesisTests, Random, Test + @testset "merge/merge! tests" begin rng = StableRNG(43) iters = (1:2, 3:10) @@ -27,6 +29,24 @@ chisq_test = ChisqTest(count_est, ps_exact) @test pvalue(chisq_test) > 0.05 end + + # Test basic merge functionality for AlgR and AlgL + for alg in [AlgR(), AlgL()] + s1 = ReservoirSampler{Int}(rng, 2, alg) + s2 = ReservoirSampler{Int}(rng, 2, alg) + + # Test empty merge + s_merged_empty = merge(s1, s2) + @test length(value(s_merged_empty)) == 0 + + # Test merge! for empty + s_copy = ReservoirSampler{Int}(rng, 2, alg) + s_other = ReservoirSampler{Int}(rng, 2, alg) + s_merged_empty_mut = merge!(s_copy, s_other) + @test s_merged_empty_mut === s_copy + @test length(value(s_merged_empty_mut)) == 0 + end + s1 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP()) s2 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP()) s_all = (s1, s2)