Skip to content

Commit a6392c5

Browse files
denizyuretKristofferC
authored andcommitted
fix #30643, correctly propagate iterator traits through Stateful (#30644)
1 parent 2010bff commit a6392c5

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

base/iterators.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,10 +1094,9 @@ end
10941094

10951095
@inline peek(s::Stateful, sentinel=nothing) = s.nextvalstate !== nothing ? s.nextvalstate[1] : sentinel
10961096
@inline iterate(s::Stateful, state=nothing) = s.nextvalstate === nothing ? nothing : (popfirst!(s), nothing)
1097-
IteratorSize(::Type{Stateful{VS,T}} where VS) where {T} =
1098-
isa(IteratorSize(T), SizeUnknown) ? SizeUnknown() : HasLength()
1097+
IteratorSize(::Type{Stateful{T,VS}}) where {T,VS} = IteratorSize(T) isa HasShape ? HasLength() : IteratorSize(T)
10991098
eltype(::Type{Stateful{T, VS}} where VS) where {T} = eltype(T)
1100-
IteratorEltype(::Type{Stateful{VS,T}} where VS) where {T} = IteratorEltype(T)
1099+
IteratorEltype(::Type{Stateful{T,VS}}) where {T,VS} = IteratorEltype(T)
11011100
length(s::Stateful) = length(s.itr) - s.taken
11021101

11031102
end

test/iterators.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,3 +549,27 @@ end
549549
@test ps isa Iterators.Pairs
550550
@test collect(ps) == [1 => :a, 2 => :b]
551551
end
552+
553+
@testset "Stateful fix #30643" begin
554+
@test Base.IteratorSize(1:10) isa Base.HasShape
555+
a = Iterators.Stateful(1:10)
556+
@test Base.IteratorSize(a) isa Base.HasLength
557+
@test length(a) == 10
558+
@test length(collect(a)) == 10
559+
@test length(a) == 0
560+
b = Iterators.Stateful(Iterators.take(1:10,3))
561+
@test Base.IteratorSize(b) isa Base.HasLength
562+
@test length(b) == 3
563+
@test length(collect(b)) == 3
564+
@test length(b) == 0
565+
c = Iterators.Stateful(Iterators.countfrom(1))
566+
@test Base.IteratorSize(c) isa Base.IsInfinite
567+
@test length(Iterators.take(c,3)) == 3
568+
@test length(collect(Iterators.take(c,3))) == 3
569+
d = Iterators.Stateful(Iterators.filter(isodd,1:10))
570+
@test Base.IteratorSize(d) isa Base.SizeUnknown
571+
@test length(collect(Iterators.take(d,3))) == 3
572+
@test length(collect(d)) == 2
573+
@test length(collect(d)) == 0
574+
end
575+

0 commit comments

Comments
 (0)