Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/SamplingReduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ function get_ps(ss::MultiAlgWRSWRSKIPSampler...)
sum_w = sum(getfield(s, :state) for s in ss)
return [s.state/sum_w for s in ss]
end
# Selection probabilities for a group of `MultiAlgRSampler`s: each sampler is
# weighted by its share of the total `seen_k` across all samplers.
# Returns one probability per sampler, in argument order.
function get_ps(ss::MultiAlgRSampler...)
    total_seen = sum(s -> getfield(s, :seen_k), ss)
    if total_seen == 0
        # No sampler has seen anything: fall back to equal weights so that
        # we never divide by zero.
        return fill(1.0/length(ss), length(ss))
    end
    return [s.seen_k/total_seen for s in ss]
end
# Selection probabilities for a group of `MultiAlgLSampler`s: each sampler is
# weighted by its share of the total `seen_k` across all samplers.
# Returns one probability per sampler, in argument order.
function get_ps(ss::MultiAlgLSampler...)
    total_seen = sum(s -> getfield(s, :seen_k), ss)
    if total_seen == 0
        # No sampler has seen anything: fall back to equal weights so that
        # we never divide by zero.
        return fill(1.0/length(ss), length(ss))
    end
    return [s.seen_k/total_seen for s in ss]
end

# `TypeS` case: all samplers share the same type `T`, so the result element
# type is just the element type of the first sampler.
function get_type_rs(::TypeS, s1::T, ss::T...) where {T}
    return eltype(s1)
end
function get_type_rs(::TypeUnion, s1::T, ss::T...) where {T}
Expand Down
48 changes: 38 additions & 10 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,19 @@ end
# Ordering trait for the RSWRSKIP samplers: the `Ord` variant reports `true`,
# the plain variant reports `false`.
function is_ordered(::MultiOrdAlgRSWRSKIPSampler)
    return true
end
function is_ordered(::MultiAlgRSWRSKIPSampler)
    return false
end

function Base.merge(ss::MultiAlgRSampler...)
error("To Be Implemented")
end
function Base.merge(ss::MultiAlgLSampler...)
error("To Be Implemented")
# Build a new sampler from several `MultiAlgRSampler`s without modifying them.
# The merged sample comes from `reduce_samples`, driven by the per-sampler
# probabilities from `get_ps` and each sampler's own rng. The result keeps the
# smallest `n` among the inputs, the summed `seen_k`, and the first sampler's rng.
function Base.merge(ss::MultiAlgRSampler...)
    probs = get_ps(ss...)
    rngs = [s.rng for s in ss]
    merged = reduce_samples(probs, rngs, TypeUnion(), map(value, ss)...)
    total_seen = sum(s -> getfield(s, :seen_k), ss)
    smallest_n = minimum(s -> s.n, ss)
    return MultiAlgRSampler_Mut(smallest_n, total_seen, first(ss).rng, merged, nothing)
end
# Build a new sampler from several `MultiAlgLSampler`s without modifying them.
# The merged sample comes from `reduce_samples`, driven by the per-sampler
# probabilities from `get_ps` and each sampler's own rng. The result keeps the
# smallest `n` among the inputs, the first sampler's rng, and the sums of
# `state`, `skip_k` and `seen_k`.
# NOTE(review): `state` and `skip_k` are simply added up here; whether a sum is
# a valid combined state for this algorithm cannot be told from this block —
# confirm against the sampler's update rule.
function Base.merge(ss::MultiAlgLSampler...)
    probs = get_ps(ss...)
    rngs = [s.rng for s in ss]
    merged = reduce_samples(probs, rngs, TypeUnion(), map(value, ss)...)
    total_state = sum(s -> getfield(s, :state), ss)
    total_skip = sum(s -> getfield(s, :skip_k), ss)
    total_seen = sum(s -> getfield(s, :seen_k), ss)
    smallest_n = minimum(s -> s.n, ss)
    return MultiAlgLSampler_Mut(
        smallest_n, total_state, total_skip, total_seen, first(ss).rng, merged, nothing
    )
end
function Base.merge(ss::MultiAlgRSWRSKIPSampler...)
newvalue = reduce_samples(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...)
Expand All @@ -223,11 +231,31 @@ function Base.merge(ss::MultiAlgRSWRSKIPSampler...)
return MultiAlgRSWRSKIPSampler_Mut(n, skip_k, seen_k, ss[1].rng, newvalue, nothing)
end

function Base.merge!(ss::MultiAlgRSampler...)
error("To Be Implemented")
end
function Base.merge!(ss::MultiAlgLSampler...)
error("To Be Implemented")
# Generic fallback for in-place merging of `MultiAlgRSampler`s: not implemented,
# so it always raises an error. More specific `merge!` methods take precedence
# over this one when they match.
Base.merge!(ss::MultiAlgRSampler...) = error("To Be Implemented")
# Generic fallback for in-place merging of `MultiAlgLSampler`s: not implemented,
# so it always raises an error. More specific `merge!` methods take precedence
# over this one when they match.
Base.merge!(ss::MultiAlgLSampler...) = error("To Be Implemented")
# Merge the samplers `ss` into `s1` in place and return `s1`.
#
# - `s1` must not have a larger `n` than any sampler in `ss`; otherwise an
#   error is raised.
# - The merged sample is computed by `reduce_samples` using the probabilities
#   from `get_ps` (evaluated before `s1.seen_k` is updated) and each sampler's
#   own rng, then copied element by element into `s1.value`.
# - `s1.seen_k` is increased by the `seen_k` of every sampler in `ss`.
# - Calling with no other samplers is a no-op that returns `s1` unchanged.
function Base.merge!(s1::MultiAlgRSampler{<:Nothing}, ss::MultiAlgRSampler...)
    # Nothing to merge: avoid calling `minimum`/`sum` on an empty collection.
    isempty(ss) && return s1
    s1.n > minimum(s.n for s in ss) &&
        error("The size of the mutated reservoir should be the minimum size between all merged reservoir")
    newvalue = reduce_samples(
        get_ps(s1, ss...), [s1.rng, [s.rng for s in ss]...], TypeS(),
        value(s1), value.(ss)...,
    )
    # The indices come from `newvalue`, not from `s1.value`, so the writes stay
    # bounds-checked: a size mismatch raises a BoundsError instead of writing
    # past the end of the reservoir.
    for i in eachindex(newvalue)
        s1.value[i] = newvalue[i]
    end
    s1.seen_k += sum(getfield(s, :seen_k) for s in ss)
    return s1
end
# Merge the samplers `ss` into `s1` in place and return `s1`.
#
# - `s1` must not have a larger `n` than any sampler in `ss`; otherwise an
#   error is raised.
# - The merged sample is computed by `reduce_samples` using the probabilities
#   from `get_ps` (evaluated before the counters of `s1` are updated) and each
#   sampler's own rng, then copied element by element into `s1.value`.
# - `s1.state`, `s1.skip_k` and `s1.seen_k` are each increased by the
#   corresponding field of every sampler in `ss`.
# - Calling with no other samplers is a no-op that returns `s1` unchanged.
#
# NOTE(review): `state` and `skip_k` are simply added up; whether a sum is a
# valid combined state for this algorithm cannot be told from this block —
# confirm against the sampler's update rule.
function Base.merge!(s1::MultiAlgLSampler{<:Nothing}, ss::MultiAlgLSampler...)
    # Nothing to merge: avoid calling `minimum`/`sum` on an empty collection.
    isempty(ss) && return s1
    s1.n > minimum(s.n for s in ss) &&
        error("The size of the mutated reservoir should be the minimum size between all merged reservoir")
    newvalue = reduce_samples(
        get_ps(s1, ss...), [s1.rng, [s.rng for s in ss]...], TypeS(),
        value(s1), value.(ss)...,
    )
    # The indices come from `newvalue`, not from `s1.value`, so the writes stay
    # bounds-checked: a size mismatch raises a BoundsError instead of writing
    # past the end of the reservoir.
    for i in eachindex(newvalue)
        s1.value[i] = newvalue[i]
    end
    s1.state += sum(getfield(s, :state) for s in ss)
    s1.skip_k += sum(getfield(s, :skip_k) for s in ss)
    s1.seen_k += sum(getfield(s, :seen_k) for s in ss)
    return s1
end
function Base.merge!(s1::MultiAlgRSWRSKIPSampler{<:Nothing}, ss::MultiAlgRSWRSKIPSampler...)
s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir")
Expand Down
20 changes: 20 additions & 0 deletions test/merge_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@

using StreamSampling, StableRNGs, HypothesisTests, Random, Test

@testset "merge/merge! tests" begin
rng = StableRNG(43)
iters = (1:2, 3:10)
Expand Down Expand Up @@ -27,6 +29,24 @@
chisq_test = ChisqTest(count_est, ps_exact)
@test pvalue(chisq_test) > 0.05
end

# Test basic merge functionality for AlgR and AlgL
# Smoke-test `merge` and `merge!` for AlgR and AlgL on reservoirs that have
# not been fed any element yet.
for alg in (AlgR(), AlgL())
    s1 = ReservoirSampler{Int}(rng, 2, alg)
    s2 = ReservoirSampler{Int}(rng, 2, alg)

    # Merging two untouched reservoirs must give an empty sample
    s_merged_empty = merge(s1, s2)
    @test isempty(value(s_merged_empty))

    # In-place merge must hand back its first argument, still empty
    s_copy = ReservoirSampler{Int}(rng, 2, alg)
    s_other = ReservoirSampler{Int}(rng, 2, alg)
    s_merged_empty_mut = merge!(s_copy, s_other)
    @test s_merged_empty_mut === s_copy
    @test isempty(value(s_merged_empty_mut))
end

s1 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP())
s2 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP())
s_all = (s1, s2)
Expand Down
Loading