diff --git a/src/combinations.jl b/src/combinations.jl index ec466a9..9237f41 100644 --- a/src/combinations.jl +++ b/src/combinations.jl @@ -2,7 +2,8 @@ export combinations, CoolLexCombinations, multiset_combinations, with_replacement_combinations, - powerset + powerset, + nthcombo #The Combinations iterator struct Combinations @@ -273,3 +274,113 @@ function powerset(a, min::Integer=0, max::Integer=length(a)) min < 1 && append!(itrs, eltype(a)[]) Iterators.flatten(itrs) end + + +# Nth Combination + +""" + nthcombo(a, k::Int, n::Int) + +Compute the `n`th lexicographic k-combination of the vector `a`. + +# Examples +```jldoctest +julia> collect(combinations([1,2,3], 2)) +3-element Vector{Vector{Int64}}: + [1, 2] + [1, 3] + [2, 3] + +julia> nthcombo([1, 2, 3], 2, 1) +2-element Vector{Int64}: + 1 + 2 + +julia> nthcombo([1, 2, 3], 2, 2) +2-element Vector{Int64}: + 1 + 3 + +julia> nthcombo([1, 2, 3], 4, 2) +ERROR: ArgumentError: combination k must satisfy 0 ≤ k ≤ 3, got 4 +[...] +``` +""" +function nthcombo(a, k::Int, n::Int) + len = length(a) + 0 ≤ k ≤ len || throw(ArgumentError("combination k must satisfy 0 ≤ k ≤ $len, got $k")) + ncombos = binomial(len, k) + 0 < n ≤ ncombos || throw(ArgumentError("n must satisfy 0 < n ≤ $ncombos, got $n")) + (k == 0 || k == len) && return collect(a)[1:k] + + combo = eltype(a)[] + sizehint!(combo, k) + ncombos *= k + ncombos ÷= len + for i in eachindex(a) + if n ≤ ncombos + @inbounds push!(combo, a[i]) + isone(k) && return combo + k -= 1 + ncombos *= k + else + n -= ncombos + ncombos *= len - k + end + len -= 1 + ncombos ÷= len + end +end + +""" + nthcombo(a, c::Vector) + +Return the integer `n` that generated index-based lexicographic combination `c` from `a`. +Note that `nthcombo(a, nthcombo(a, k, n)) == n` for `1 ≤ n ≤ binomial(length(a), k)` and `unique(a) == a`. +In the case `unique(a) ≠ a`, returns the lowest `n` matching the combination, and +thus is not guaranteed to be the inverse of `nthcombo(a, k, n)`. + +# Examples +```jldoctest +julia> nthcombo([1:3...], nthcombo([1:3...], 2, 3)) +3 + +julia> collect(combinations([1, 2, 3], 2)) +3-element Vector{Vector{Int64}}: + [1, 2] + [1, 3] + [2, 3] + +julia> nthcombo([1, 2, 3], [1, 2]) +1 + +julia> nthcombo([1, 2, 3], [2, 3]) +3 +``` +""" +function nthcombo(a, combo::Vector) + isempty(combo) && return 1 + iscombo(a, combo) || throw(ArgumentError("$combo not a combination of $a")) + + aunique = unique(a) + idxmap = Dict(zip(aunique, 1:length(aunique))) + idxs = [idxmap[v] for v in combo] + ranges = collect(zip([0; idxs[1:end-1]] .+ 1, idxs .- 1)) + m, k = length(a), length(combo) + + n = 1 + for i in 1:k + lower, upper = ranges[i] + if upper - lower ≥ 0 + n += sum(binomial.(m .- collect(lower:upper), k - i)) + end + end + n +end + +function iscombo(a, combo) + counts = Dict{eltype(a), Int}() + foreach(key -> counts[key] = get(counts, key, 0) + 1, a) + foreach(key -> counts[key] = get(counts, key, 0) - 1, combo) + all(v -> v ≥ 0, (0, values(counts)...)) +end diff --git a/test/combinations.jl b/test/combinations.jl index f70a566..2e230e3 100644 --- a/test/combinations.jl +++ b/test/combinations.jl @@ -44,4 +44,24 @@ @test collect(powerset(['a', 'b', 'c'], 1)) == Any[['a'], ['b'], ['c'], ['a', 'b'], ['a', 'c'], ['b', 'c'], ['a', 'b', 'c']] @test collect(powerset(['a', 'b', 'c'], 1, 2)) == Any[['a'], ['b'], ['c'], ['a', 'b'], ['a', 'c'], ['b', 'c']] + # Nth Combo + @test nthcombo([1, 2, 3, 4], 0, 1) == [] + @test nthcombo([1, 2, 3, 4], 4, 1) == [1, 2, 3, 4] + @test nthcombo([1, 2, 3, 4], 3, 2) == [1, 2, 4] + @test all([nthcombo([1, 2, 3, 4], 2, n) for n in 1:binomial(4, 2)] .== collect(combinations([1, 2, 3, 4], 2))) + @test_throws ArgumentError nthcombo([1, 2, 3, 4], 0, 3) + @test_throws ArgumentError nthcombo([1, 2, 3, 4], 5, 3) + @test_throws ArgumentError nthcombo([1, 2, 3, 4], 2, 0) + @test_throws ArgumentError nthcombo([1, 2, 3], 2, 6) + + @test nthcombo([1, 2, 3, 4], []) == 1 + @test nthcombo([1, 2, 3, 4], [1, 2, 3, 4]) == 1 + @test nthcombo([1, 2, 3, 4], [1, 2, 4]) == 2 + @test [nthcombo(1:7, combo) for combo in combinations(1:7, 3)] == collect(1:binomial(7, 3)) + @test_throws ArgumentError nthcombo([1, 2, 3, 4], [1, 5]) + @test_throws ArgumentError nthcombo([1, 2, 3], [1, 2, 3, 3]) + + data = collect(1:7) + @test all([nthcombo(data, nthcombo(data, k, j)) == j for k in 1:7 for j in 1:binomial(7, k)]) + end