diff --git a/src/combinations.jl b/src/combinations.jl index bc8898d..60f449f 100644 --- a/src/combinations.jl +++ b/src/combinations.jl @@ -5,37 +5,33 @@ export combinations, powerset #The Combinations iterator - -struct Combinations{T} - a::T +struct Combinations + n::Int t::Int end -function Base.iterate(c::Combinations, s = collect(1:c.t)) - (!isempty(s) && s[1] > length(c.a) - c.t + 1) && return - - comb = [c.a[si] for si in s] - if c.t == 0 - # special case to generate 1 result for t==0 - return (comb, [length(c.a)+2]) +function Base.iterate(c::Combinations, s = [min(c.t - 1, i) for i in 1:c.t]) + if c.t == 0 # special case to generate 1 result for t==0 + isempty(s) && return (s, [1]) + return end - s = copy(s) - for i = length(s):-1:1 + for i in c.t:-1:1 s[i] += 1 - if s[i] > (length(c.a) - (length(s) - i)) + if s[i] > (c.n - (c.t - i)) continue end - for j = i+1:lastindex(s) - s[j] = s[j-1]+1 + for j in i+1:c.t + s[j] = s[j-1] + 1 end break end - (comb, s) + s[1] > c.n - c.t + 1 && return + (s, s) end -Base.length(c::Combinations) = binomial(length(c.a), c.t) +Base.length(c::Combinations) = binomial(c.n, c.t) -Base.eltype(::Type{Combinations{T}}) where {T} = Vector{eltype(T)} +Base.eltype(::Type{Combinations}) = Vector{Int} """ combinations(a, n) @@ -49,7 +45,8 @@ function combinations(a, t::Integer) # generate 0 combinations for negative argument t = length(a) + 1 end - Combinations(a, t) + reorder(c) = [a[ci] for ci in c] + (reorder(c) for c in Combinations(length(a), t)) end diff --git a/src/multinomials.jl b/src/multinomials.jl index 3f2dfd2..c3e5426 100644 --- a/src/multinomials.jl +++ b/src/multinomials.jl @@ -4,7 +4,7 @@ export multiexponents struct MultiExponents{T} - c::Combinations{T} + c::T nterms::Int end diff --git a/test/combinations.jl b/test/combinations.jl index ea5c655..2e7c74c 100644 --- a/test/combinations.jl +++ b/test/combinations.jl @@ -1,10 +1,10 @@ @test [combinations([])...] == [] -@test [combinations(['a', 'b', 'c'])...] == Any[['a'],['b'],['c'],['a','b'],['a','c'],['b','c'],['a','b','c']] +@test [combinations(['a', 'b', 'c'])...] == [['a'],['b'],['c'],['a','b'],['a','c'],['b','c'],['a','b','c']] -@test [combinations("abc",3)...] == Any[['a','b','c']] -@test [combinations("abc",2)...] == Any[['a','b'],['a','c'],['b','c']] -@test [combinations("abc",1)...] == Any[['a'],['b'],['c']] -@test [combinations("abc",0)...] == Any[[]] +@test [combinations("abc",3)...] == [['a','b','c']] +@test [combinations("abc",2)...] == [['a','b'],['a','c'],['b','c']] +@test [combinations("abc",1)...] == [['a'],['b'],['c']] +@test [combinations("abc",0)...] == [[]] @test [combinations("abc",-1)...] == [] @test filter(x->iseven(x[1]),[combinations([1,2,3],2)...]) == Any[[2,3]]