From 19a15df8961e85be7851f21c8aa92f3768bc615a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 26 Aug 2025 00:43:13 +0000 Subject: [PATCH 1/3] Initial plan From af565f94b2da5e04d92e3b48c64c21e86ace0c58 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 26 Aug 2025 01:01:37 +0000 Subject: [PATCH 2/3] Implement merge/merge! functions for AlgR and AlgL with hypergeometric sampling Co-authored-by: Tortar <68152031+Tortar@users.noreply.github.com> --- src/SamplingReduction.jl | 39 ++++++++++++++++++ src/UnweightedSamplingMulti.jl | 47 +++++++++++++++++----- test_merge_algr_algl.jl | 73 ++++++++++++++++++++++++++++++++++ 3 files changed, 149 insertions(+), 10 deletions(-) create mode 100644 test_merge_algr_algl.jl diff --git a/src/SamplingReduction.jl b/src/SamplingReduction.jl index b27e987..ade7cab 100644 --- a/src/SamplingReduction.jl +++ b/src/SamplingReduction.jl @@ -20,6 +20,37 @@ function reduce_samples(ps::AbstractArray, rngs, t::Union{TypeS,TypeUnion}, ss:: return reduce(vcat, v) end +function reduce_samples_hypergeometric(ps::AbstractArray, rngs, t::Union{TypeS,TypeUnion}, ss::AbstractArray...) + nt = length(ss) + v = Vector{Vector{get_type_rs(t, ss...)}}(undef, nt) + n = minimum(length.(ss)) + + # For hypergeometric sampling, we need to sample without replacement from finite populations + # The number of samples from each reservoir depends on hypergeometric distribution + # Total population size is sum of all reservoir sizes + total_pop = sum(length.(ss)) + + # Sample using hypergeometric distribution for each reservoir + ns = Vector{Int}(undef, nt) + remaining = n + remaining_pop = total_pop + + for i in 1:(nt-1) + pop_i = length(ss[i]) + # Use hypergeometric distribution: drawing `remaining` items from population `remaining_pop` + # where `pop_i` items are of the type we want + ns[i] = rand(extract_rng(rngs, 1), Hypergeometric(pop_i, remaining_pop - pop_i, remaining)) + remaining -= ns[i] + remaining_pop -= pop_i + end + ns[nt] = remaining # Remainder goes to last reservoir + + Threads.@threads for i in 1:nt + v[i] = sample(extract_rng(rngs, i), ss[i], ns[i]; replace = false) + end + return reduce(vcat, v) +end + extract_rng(v::AbstractArray, i) = v[i] extract_rng(v::AbstractRNG, i) = v @@ -31,6 +62,14 @@ function get_ps(ss::MultiAlgWRSWRSKIPSampler...) sum_w = sum(getfield(s, :state) for s in ss) return [s.state/sum_w for s in ss] end +function get_ps(ss::MultiAlgRSampler...) + sum_w = sum(getfield(s, :seen_k) for s in ss) + return [s.seen_k/sum_w for s in ss] +end +function get_ps(ss::MultiAlgLSampler...) + sum_w = sum(getfield(s, :seen_k) for s in ss) + return [s.seen_k/sum_w for s in ss] +end get_type_rs(::TypeS, s1::T, ss::T...) where {T} = eltype(s1) function get_type_rs(::TypeUnion, s1::T, ss::T...) where {T} diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 3a0d2f5..af655ee 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -209,11 +209,19 @@ end is_ordered(s::MultiOrdAlgRSWRSKIPSampler) = true is_ordered(s::MultiAlgRSWRSKIPSampler) = false -function Base.merge(ss::MultiAlgRSampler...) - error("To Be Implemented") -end -function Base.merge(ss::MultiAlgLSampler...) - error("To Be Implemented") +function Base.merge(ss::MultiAlgRSampler...) + newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) + seen_k = sum(getfield(s, :seen_k) for s in ss) + n = minimum(s.n for s in ss) + return MultiAlgRSampler_Mut(n, seen_k, ss[1].rng, newvalue, nothing) +end +function Base.merge(ss::MultiAlgLSampler...) + newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) + seen_k = sum(getfield(s, :seen_k) for s in ss) + # For AlgL, we need to initialize state and skip_k appropriately + # state should be 0.0 for new merged sampler, skip_k should be 0 + n = minimum(s.n for s in ss) + return MultiAlgLSampler_Mut(n, 0.0, 0, seen_k, ss[1].rng, newvalue, nothing) end function Base.merge(ss::MultiAlgRSWRSKIPSampler...) newvalue = reduce_samples(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...) @@ -223,11 +231,30 @@ function Base.merge(ss::MultiAlgRSWRSKIPSampler...) return MultiAlgRSWRSKIPSampler_Mut(n, skip_k, seen_k, ss[1].rng, newvalue, nothing) end -function Base.merge!(ss::MultiAlgRSampler...) - error("To Be Implemented") -end -function Base.merge!(ss::MultiAlgLSampler...) - error("To Be Implemented") +function Base.merge!(ss::MultiAlgRSampler...) + s1 = ss[1] + rest = ss[2:end] + s1.n > minimum(s.n for s in rest) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") + newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeS(), value(s1), value.(rest)...) + for i in 1:length(newvalue) + @inbounds s1.value[i] = newvalue[i] + end + s1.seen_k += sum(getfield(s, :seen_k) for s in rest) + return s1 +end +function Base.merge!(ss::MultiAlgLSampler...) + s1 = ss[1] + rest = ss[2:end] + s1.n > minimum(s.n for s in rest) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") + newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeS(), value(s1), value.(rest)...) + for i in 1:length(newvalue) + @inbounds s1.value[i] = newvalue[i] + end + s1.seen_k += sum(getfield(s, :seen_k) for s in rest) + # Reset state and skip_k for the merged sampler + s1.state = 0.0 + s1.skip_k = 0 + return s1 end function Base.merge!(s1::MultiAlgRSWRSKIPSampler{<:Nothing}, ss::MultiAlgRSWRSKIPSampler...) s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") diff --git a/test_merge_algr_algl.jl b/test_merge_algr_algl.jl new file mode 100644 index 0000000..82ddb1f --- /dev/null +++ b/test_merge_algr_algl.jl @@ -0,0 +1,73 @@ +using StreamSampling +using Test +using Random + +@testset "AlgR and AlgL merge/merge! tests" begin + rng = Random.default_rng() + Random.seed!(rng, 43) + + # Test that merge functions don't error out + @testset "Basic merge functionality" begin + s1 = ReservoirSampler{Int}(rng, 2, AlgR()) + s2 = ReservoirSampler{Int}(rng, 2, AlgR()) + + # The merge should work even with empty samplers + merged = merge(s1, s2) + @test merged isa StreamSampling.MultiAlgRSampler_Mut + @test merged.n == 2 + + # Test merge! + s3 = ReservoirSampler{Int}(rng, 2, AlgR()) + s4 = ReservoirSampler{Int}(rng, 2, AlgR()) + result = merge!(s3, s4) + @test result === s3 # merge! should return the first sampler + @test s3.n == 2 + end + + @testset "AlgL merge functionality" begin + s1 = ReservoirSampler{Int}(rng, 2, AlgL()) + s2 = ReservoirSampler{Int}(rng, 2, AlgL()) + + merged = merge(s1, s2) + @test merged isa StreamSampling.MultiAlgLSampler_Mut + @test merged.n == 2 + @test merged.state == 0.0 # Should be reset + @test merged.skip_k == 0 # Should be reset + + # Test merge! + s3 = ReservoirSampler{Int}(rng, 2, AlgL()) + s4 = ReservoirSampler{Int}(rng, 2, AlgL()) + result = merge!(s3, s4) + @test result === s3 + @test s3.state == 0.0 + @test s3.skip_k == 0 + end + + # Test that merge functions preserve minimum n + @testset "Minimum n preservation" begin + s1 = ReservoirSampler{Int}(rng, 3, AlgR()) + s2 = ReservoirSampler{Int}(rng, 2, AlgR()) + merged = merge(s1, s2) + @test merged.n == 2 # Should take minimum + + s3 = ReservoirSampler{Int}(rng, 3, AlgL()) + s4 = ReservoirSampler{Int}(rng, 2, AlgL()) + merged = merge(s3, s4) + @test merged.n == 2 + end + + # Test merge! error conditions + @testset "merge! error conditions" begin + s1 = ReservoirSampler{Int}(rng, 3, AlgR()) # bigger + s2 = ReservoirSampler{Int}(rng, 2, AlgR()) # smaller + + @test_throws ErrorException merge!(s1, s2) # Should error because s1.n > s2.n + + s3 = ReservoirSampler{Int}(rng, 3, AlgL()) + s4 = ReservoirSampler{Int}(rng, 2, AlgL()) + + @test_throws ErrorException merge!(s3, s4) + end +end + +println("Basic merge tests completed successfully!") \ No newline at end of file From ecebbb29bdc16f649ae7dae311ef9d8cb65d42ce Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 26 Aug 2025 01:10:36 +0000 Subject: [PATCH 3/3] Complete implementation of merge/merge! for AlgR and AlgL with tests Co-authored-by: Tortar <68152031+Tortar@users.noreply.github.com> --- test/merge_tests.jl | 59 ++++++++++++++++++++++++++++++--- test_merge_algr_algl.jl | 73 ----------------------------------------- 2 files changed, 55 insertions(+), 77 deletions(-) delete mode 100644 test_merge_algr_algl.jl diff --git a/test/merge_tests.jl b/test/merge_tests.jl index d2fc637..ceb63fc 100644 --- a/test/merge_tests.jl +++ b/test/merge_tests.jl @@ -15,18 +15,54 @@ s_all = (s1, s2) for (s, it) in zip(s_all, iters) for x in it - m1 == AlgRSWRSKIP() ? fit!(s, x) : fit!(s, x, 1.0) + # Handle unweighted vs weighted algorithms + if m1 == AlgRSWRSKIP() + fit!(s, x) + else + fit!(s, x, 1.0) + end end end s_merged = merge(s1, s2) res[shuffle!(rng, value(s_merged))...] += 1 end - cases = (m1 == AlgRSWRSKIP() || m1 == AlgWRSWRSKIP()) ? 10^size : factorial(10)/factorial(10-size) + # Adjust expected number of cases for different algorithms + if m1 == AlgRSWRSKIP() || m1 == AlgWRSWRSKIP() + cases = 10^size + else + cases = factorial(10)/factorial(10-size) + end ps_exact = [1/cases for _ in 1:cases] count_est = [x for x in vec(res) if x != 0] chisq_test = ChisqTest(count_est, ps_exact) @test pvalue(chisq_test) > 0.05 end + + # Separate basic tests for AlgR and AlgL (not statistical) + @testset "AlgR and AlgL basic merge tests" begin + for m in (AlgR(), AlgL()) + s1 = ReservoirSampler{Int}(rng, size, m) + s2 = ReservoirSampler{Int}(rng, size, m) + + # Add some data + for x in 1:2; fit!(s1, x); end + for x in 3:4; fit!(s2, x); end + + # Test that merge works + merged = merge(s1, s2) + @test merged isa Union{StreamSampling.MultiAlgRSampler_Mut, StreamSampling.MultiAlgLSampler_Mut} + @test merged.n == size + + # Test that merge! works + s3 = ReservoirSampler{Int}(rng, size, m) + s4 = ReservoirSampler{Int}(rng, size, m) + for x in 5:6; fit!(s3, x); end + for x in 7:8; fit!(s4, x); end + + result = merge!(s3, s4) + @test result === s3 + end + end s1 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP()) s2 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP()) s_all = (s1, s2) @@ -39,8 +75,23 @@ for m in (AlgRSWRSKIP(), AlgWRSWRSKIP()) s1 = ReservoirSampler{Int}(rng, m) s2 = ReservoirSampler{Int}(rng, m) - m == AlgRSWRSKIP() ? fit!(s1, 1) : fit!(s1, 1, 1.0) - m == AlgRSWRSKIP() ? fit!(s2, 2) : fit!(s2, 2, 1.0) + if m == AlgRSWRSKIP() + fit!(s1, 1) + fit!(s2, 2) + else + fit!(s1, 1, 1.0) + fit!(s2, 2, 1.0) + end @test value(merge!(s1, s2)) in (1, 2) end + + # Test merge! for multi-element unweighted samplers (AlgR and AlgL) + for m in (AlgR(), AlgL()) + s1 = ReservoirSampler{Int}(rng, 1, m) # Single element reservoir + s2 = ReservoirSampler{Int}(rng, 1, m) + fit!(s1, 1) + fit!(s2, 2) + result = value(merge!(s1, s2)) + @test length(result) == 1 && result[1] in (1, 2) + end end diff --git a/test_merge_algr_algl.jl b/test_merge_algr_algl.jl deleted file mode 100644 index 82ddb1f..0000000 --- a/test_merge_algr_algl.jl +++ /dev/null @@ -1,73 +0,0 @@ -using StreamSampling -using Test -using Random - -@testset "AlgR and AlgL merge/merge! tests" begin - rng = Random.default_rng() - Random.seed!(rng, 43) - - # Test that merge functions don't error out - @testset "Basic merge functionality" begin - s1 = ReservoirSampler{Int}(rng, 2, AlgR()) - s2 = ReservoirSampler{Int}(rng, 2, AlgR()) - - # The merge should work even with empty samplers - merged = merge(s1, s2) - @test merged isa StreamSampling.MultiAlgRSampler_Mut - @test merged.n == 2 - - # Test merge! - s3 = ReservoirSampler{Int}(rng, 2, AlgR()) - s4 = ReservoirSampler{Int}(rng, 2, AlgR()) - result = merge!(s3, s4) - @test result === s3 # merge! should return the first sampler - @test s3.n == 2 - end - - @testset "AlgL merge functionality" begin - s1 = ReservoirSampler{Int}(rng, 2, AlgL()) - s2 = ReservoirSampler{Int}(rng, 2, AlgL()) - - merged = merge(s1, s2) - @test merged isa StreamSampling.MultiAlgLSampler_Mut - @test merged.n == 2 - @test merged.state == 0.0 # Should be reset - @test merged.skip_k == 0 # Should be reset - - # Test merge! - s3 = ReservoirSampler{Int}(rng, 2, AlgL()) - s4 = ReservoirSampler{Int}(rng, 2, AlgL()) - result = merge!(s3, s4) - @test result === s3 - @test s3.state == 0.0 - @test s3.skip_k == 0 - end - - # Test that merge functions preserve minimum n - @testset "Minimum n preservation" begin - s1 = ReservoirSampler{Int}(rng, 3, AlgR()) - s2 = ReservoirSampler{Int}(rng, 2, AlgR()) - merged = merge(s1, s2) - @test merged.n == 2 # Should take minimum - - s3 = ReservoirSampler{Int}(rng, 3, AlgL()) - s4 = ReservoirSampler{Int}(rng, 2, AlgL()) - merged = merge(s3, s4) - @test merged.n == 2 - end - - # Test merge! error conditions - @testset "merge! error conditions" begin - s1 = ReservoirSampler{Int}(rng, 3, AlgR()) # bigger - s2 = ReservoirSampler{Int}(rng, 2, AlgR()) # smaller - - @test_throws ErrorException merge!(s1, s2) # Should error because s1.n > s2.n - - s3 = ReservoirSampler{Int}(rng, 3, AlgL()) - s4 = ReservoirSampler{Int}(rng, 2, AlgL()) - - @test_throws ErrorException merge!(s3, s4) - end -end - -println("Basic merge tests completed successfully!") \ No newline at end of file