From 424bf6d9d893784b6110c156dad956bed72516c3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 25 Aug 2025 23:47:39 +0000 Subject: [PATCH 1/2] Initial plan From c4a561d41b3fa4010658a72517b3feda8ef9c0be Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 26 Aug 2025 00:09:24 +0000 Subject: [PATCH 2/2] Implement merge and merge! functions for AlgR and AlgL algorithms Co-authored-by: Tortar <68152031+Tortar@users.noreply.github.com> --- src/SamplingReduction.jl | 10 +++++++ src/UnweightedSamplingMulti.jl | 48 +++++++++++++++++++++++++++------- test/merge_tests.jl | 20 ++++++++++++++ 3 files changed, 68 insertions(+), 10 deletions(-) diff --git a/src/SamplingReduction.jl b/src/SamplingReduction.jl index b27e987f..8875584e 100644 --- a/src/SamplingReduction.jl +++ b/src/SamplingReduction.jl @@ -31,6 +31,16 @@ function get_ps(ss::MultiAlgWRSWRSKIPSampler...) sum_w = sum(getfield(s, :state) for s in ss) return [s.state/sum_w for s in ss] end +function get_ps(ss::MultiAlgRSampler...) + sum_w = sum(getfield(s, :seen_k) for s in ss) + sum_w == 0 && return [1.0/length(ss) for _ in ss] # Handle empty case + return [s.seen_k/sum_w for s in ss] +end +function get_ps(ss::MultiAlgLSampler...) + sum_w = sum(getfield(s, :seen_k) for s in ss) + sum_w == 0 && return [1.0/length(ss) for _ in ss] # Handle empty case + return [s.seen_k/sum_w for s in ss] +end get_type_rs(::TypeS, s1::T, ss::T...) where {T} = eltype(s1) function get_type_rs(::TypeUnion, s1::T, ss::T...) where {T} diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 3a0d2f58..db8a4efe 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -209,11 +209,19 @@ end is_ordered(s::MultiOrdAlgRSWRSKIPSampler) = true is_ordered(s::MultiAlgRSWRSKIPSampler) = false -function Base.merge(ss::MultiAlgRSampler...) - error("To Be Implemented") -end -function Base.merge(ss::MultiAlgLSampler...) - error("To Be Implemented") +function Base.merge(ss::MultiAlgRSampler...) + newvalue = reduce_samples(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) + seen_k = sum(getfield(s, :seen_k) for s in ss) + n = minimum(s.n for s in ss) + return MultiAlgRSampler_Mut(n, seen_k, ss[1].rng, newvalue, nothing) +end +function Base.merge(ss::MultiAlgLSampler...) + newvalue = reduce_samples(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) + state = sum(getfield(s, :state) for s in ss) + skip_k = sum(getfield(s, :skip_k) for s in ss) + seen_k = sum(getfield(s, :seen_k) for s in ss) + n = minimum(s.n for s in ss) + return MultiAlgLSampler_Mut(n, state, skip_k, seen_k, ss[1].rng, newvalue, nothing) end function Base.merge(ss::MultiAlgRSWRSKIPSampler...) newvalue = reduce_samples(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) @@ -223,11 +231,31 @@ function Base.merge(ss::MultiAlgRSWRSKIPSampler...) return MultiAlgRSWRSKIPSampler_Mut(n, skip_k, seen_k, ss[1].rng, newvalue, nothing) end -function Base.merge!(ss::MultiAlgRSampler...) - error("To Be Implemented") -end -function Base.merge!(ss::MultiAlgLSampler...) - error("To Be Implemented") +function Base.merge!(ss::MultiAlgRSampler...) + error("To Be Implemented") +end +function Base.merge!(ss::MultiAlgLSampler...) + error("To Be Implemented") +end +function Base.merge!(s1::MultiAlgRSampler{<:Nothing}, ss::MultiAlgRSampler...) + s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") + newvalue = reduce_samples(get_ps(s1, ss...), [s1.rng, [s.rng for s in ss]...], TypeS(), value(s1), value.(ss)...) + for i in 1:length(newvalue) + @inbounds s1.value[i] = newvalue[i] + end + s1.seen_k += sum(getfield(s, :seen_k) for s in ss) + return s1 +end +function Base.merge!(s1::MultiAlgLSampler{<:Nothing}, ss::MultiAlgLSampler...) + s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") + newvalue = reduce_samples(get_ps(s1, ss...), [s1.rng, [s.rng for s in ss]...], TypeS(), value(s1), value.(ss)...) + for i in 1:length(newvalue) + @inbounds s1.value[i] = newvalue[i] + end + s1.state += sum(getfield(s, :state) for s in ss) + s1.skip_k += sum(getfield(s, :skip_k) for s in ss) + s1.seen_k += sum(getfield(s, :seen_k) for s in ss) + return s1 end function Base.merge!(s1::MultiAlgRSWRSKIPSampler{<:Nothing}, ss::MultiAlgRSWRSKIPSampler...) s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") diff --git a/test/merge_tests.jl b/test/merge_tests.jl index d2fc6371..ae6ab913 100644 --- a/test/merge_tests.jl +++ b/test/merge_tests.jl @@ -1,4 +1,6 @@ +using StreamSampling, StableRNGs, HypothesisTests, Random, Test + @testset "merge/merge! tests" begin rng = StableRNG(43) iters = (1:2, 3:10) @@ -27,6 +29,24 @@ chisq_test = ChisqTest(count_est, ps_exact) @test pvalue(chisq_test) > 0.05 end + + # Test basic merge functionality for AlgR and AlgL + for alg in [AlgR(), AlgL()] + s1 = ReservoirSampler{Int}(rng, 2, alg) + s2 = ReservoirSampler{Int}(rng, 2, alg) + + # Test empty merge + s_merged_empty = merge(s1, s2) + @test length(value(s_merged_empty)) == 0 + + # Test merge! for empty + s_copy = ReservoirSampler{Int}(rng, 2, alg) + s_other = ReservoirSampler{Int}(rng, 2, alg) + s_merged_empty_mut = merge!(s_copy, s_other) + @test s_merged_empty_mut === s_copy + @test length(value(s_merged_empty_mut)) == 0 + end + s1 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP()) s2 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP()) s_all = (s1, s2)