From 21dfef3f1656d952540f394145c24552cd37b0a2 Mon Sep 17 00:00:00 2001 From: denizyuret Date: Wed, 9 Jan 2019 18:30:45 -0500 Subject: [PATCH] fix #30643, correctly propagate iterator traits through Stateful (#30644) --- base/iterators.jl | 5 ++--- test/iterators.jl | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/base/iterators.jl b/base/iterators.jl index b329a6c6f9eff..5d1a133a5ea4f 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -1089,10 +1089,9 @@ end @inline peek(s::Stateful, sentinel=nothing) = s.nextvalstate !== nothing ? s.nextvalstate[1] : sentinel @inline iterate(s::Stateful, state=nothing) = s.nextvalstate === nothing ? nothing : (popfirst!(s), nothing) -IteratorSize(::Type{Stateful{VS,T}} where VS) where {T} = - isa(IteratorSize(T), SizeUnknown) ? SizeUnknown() : HasLength() +IteratorSize(::Type{Stateful{T,VS}}) where {T,VS} = IteratorSize(T) isa HasShape ? HasLength() : IteratorSize(T) eltype(::Type{Stateful{T, VS}} where VS) where {T} = eltype(T) -IteratorEltype(::Type{Stateful{VS,T}} where VS) where {T} = IteratorEltype(T) +IteratorEltype(::Type{Stateful{T,VS}}) where {T,VS} = IteratorEltype(T) length(s::Stateful) = length(s.itr) - s.taken end diff --git a/test/iterators.jl b/test/iterators.jl index 86b2cd1d2969d..25a2df2ce8021 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -623,3 +623,26 @@ end @test @inferred(first(z)) == (1, "a", 1.0, 1, "a", 1.0, 1, "a", 1.0) @test @inferred(first(Iterators.drop(z, 1))) == (2, "b", 2.0, 2, "a", 1.2, 1, "c", 1.0) end + +@testset "Stateful fix #30643" begin + @test Base.IteratorSize(1:10) isa Base.HasShape + a = Iterators.Stateful(1:10) + @test Base.IteratorSize(a) isa Base.HasLength + @test length(a) == 10 + @test length(collect(a)) == 10 + @test length(a) == 0 + b = Iterators.Stateful(Iterators.take(1:10,3)) + @test Base.IteratorSize(b) isa Base.HasLength + @test length(b) == 3 + @test length(collect(b)) == 3 + @test length(b) == 0 + c = Iterators.Stateful(Iterators.countfrom(1)) + @test Base.IteratorSize(c) isa Base.IsInfinite + @test length(Iterators.take(c,3)) == 3 + @test length(collect(Iterators.take(c,3))) == 3 + d = Iterators.Stateful(Iterators.filter(isodd,1:10)) + @test Base.IteratorSize(d) isa Base.SizeUnknown + @test length(collect(Iterators.take(d,3))) == 3 + @test length(collect(d)) == 2 + @test length(collect(d)) == 0 +end