diff --git a/base/iterators.jl b/base/iterators.jl index fcf04c358fa9e..867120f6cd5f1 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -1016,44 +1016,56 @@ mutable struct Stateful{T, VS} # A bit awkward right now, but adapted to the new iteration protocol nextvalstate::Union{VS, Nothing} taken::Int - # Try to find an appropriate type for the (value, state tuple), - # by doing a recursive unrolling of the iteration protocol up to - # fixpoint. - function fixpoint_iter_type(itrT::Type, valT::Type, stateT::Type) - nextvalstate = Base._return_type(next, Tuple{itrT, stateT}) - nextvalstate <: Tuple{Any, Any} || return Any - nextvalstate = Tuple{ - typejoin(valT, fieldtype(nextvalstate, 1)), - typejoin(stateT, fieldtype(nextvalstate, 2))} - return (Tuple{valT, stateT} == nextvalstate ? nextvalstate : - fixpoint_iter_type(itrT, - fieldtype(nextvalstate, 1), - fieldtype(nextvalstate, 2))) - end - function Stateful(itr::T) where {T} + @inline function Stateful(itr::T) where {T} state = start(itr) VS = fixpoint_iter_type(T, Union{}, typeof(state)) - vs = done(itr, state) ? nothing : next(itr, state)::VS - new{T, VS}(itr, vs, 0) + if done(itr, state) + new{T, VS}(itr, nothing, 0) + else + new{T, VS}(itr, next(itr, state)::VS, 0) + end end end +# Try to find an appropriate type for the (value, state tuple), +# by doing a recursive unrolling of the iteration protocol up to +# fixpoint. +function fixpoint_iter_type(itrT::Type, valT::Type, stateT::Type) + nextvalstate = Base._return_type(next, Tuple{itrT, stateT}) + nextvalstate <: Tuple{Any, Any} || return Any + nextvalstate = Tuple{ + typejoin(valT, fieldtype(nextvalstate, 1)), + typejoin(stateT, fieldtype(nextvalstate, 2))} + return (Tuple{valT, stateT} == nextvalstate ? nextvalstate : + fixpoint_iter_type(itrT, + fieldtype(nextvalstate, 1), + fieldtype(nextvalstate, 2))) +end + convert(::Type{Stateful}, itr) = Stateful(itr) -isempty(s::Stateful) = s.nextvalstate === nothing +@inline isempty(s::Stateful) = s.nextvalstate === nothing -function popfirst!(s::Stateful) - isempty(s) && throw(EOFError()) - val, state = s.nextvalstate - s.nextvalstate = done(s.itr, state) ? nothing : next(s.itr, state) - s.taken += 1 - val +@inline function popfirst!(s::Stateful) + vs = s.nextvalstate + if vs === nothing + throw(EOFError()) + else + val, state = vs + if done(s.itr, state) + s.nextvalstate = nothing + else + s.nextvalstate = next(s.itr, state) + end + s.taken += 1 + return val + end end -peek(s::Stateful, sentinel=nothing) = s.nextvalstate !== nothing ? s.nextvalstate[1] : sentinel -start(s::Stateful) = nothing -next(s::Stateful, state) = popfirst!(s), nothing -done(s::Stateful, state) = isempty(s) +@inline peek(s::Stateful, sentinel=nothing) = s.nextvalstate !== nothing ? s.nextvalstate[1] : sentinel +@inline start(s::Stateful) = nothing +@inline next(s::Stateful, state) = popfirst!(s), nothing +@inline done(s::Stateful, state) = isempty(s) IteratorSize(::Type{Stateful{VS,T}} where VS) where {T} = isa(IteratorSize(T), SizeUnknown) ? SizeUnknown() : HasLength() eltype(::Type{Stateful{VS, T}} where VS) where {T} = eltype(T) diff --git a/test/iterators.jl b/test/iterators.jl index 19efa5b31b31c..634a32f923149 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -495,15 +495,15 @@ end end @testset "Iterators.Stateful" begin - let a = Iterators.Stateful("abcdef") + let a = @inferred(Iterators.Stateful("abcdef")) @test !isempty(a) @test popfirst!(a) == 'a' @test collect(Iterators.take(a, 3)) == ['b','c','d'] @test collect(a) == ['e', 'f'] end - let a = Iterators.Stateful([1, 1, 1, 2, 3, 4]) + let a = @inferred(Iterators.Stateful([1, 1, 1, 2, 3, 4])) for x in a; x == 1 || break; end @test Base.peek(a) == 3 @test sum(a) == 7 end -end \ No newline at end of file +end