Skip to content

Commit

Permalink
Tune Stateful iterator
Browse files Browse the repository at this point in the history
This attempts to address some of the performance regressions observed
with the Stateful iterator #25763. It gets most of the way there,
but unfortunately still ends up allocating the `Stateful` iterator
object rather than propagating through the fields. Getting the rest
of the way there will require some compiler tweaks.
  • Loading branch information
Keno committed Jan 29, 2018
1 parent 787550e commit 9c1b88a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
68 changes: 40 additions & 28 deletions base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1016,44 +1016,56 @@ mutable struct Stateful{T, VS}
# A bit awkward right now, but adapted to the new iteration protocol
nextvalstate::Union{VS, Nothing}
taken::Int
# Try to find an appropriate type for the (value, state tuple),
# by doing a recursive unrolling of the iteration protocol up to
# fixpoint.
function fixpoint_iter_type(itrT::Type, valT::Type, stateT::Type)
nextvalstate = Base._return_type(next, Tuple{itrT, stateT})
nextvalstate <: Tuple{Any, Any} || return Any
nextvalstate = Tuple{
typejoin(valT, fieldtype(nextvalstate, 1)),
typejoin(stateT, fieldtype(nextvalstate, 2))}
return (Tuple{valT, stateT} == nextvalstate ? nextvalstate :
fixpoint_iter_type(itrT,
fieldtype(nextvalstate, 1),
fieldtype(nextvalstate, 2)))
end
function Stateful(itr::T) where {T}
@inline function Stateful(itr::T) where {T}
state = start(itr)
VS = fixpoint_iter_type(T, Union{}, typeof(state))
vs = done(itr, state) ? nothing : next(itr, state)::VS
new{T, VS}(itr, vs, 0)
if done(itr, state)
new{T, VS}(itr, nothing, 0)
else
new{T, VS}(itr, next(itr, state)::VS, 0)
end
end
end

# Try to find an appropriate type for the (value, state tuple),
# by doing a recursive unrolling of the iteration protocol up to
# fixpoint.
function fixpoint_iter_type(itrT::Type, valT::Type, stateT::Type)
nextvalstate = Base._return_type(next, Tuple{itrT, stateT})
nextvalstate <: Tuple{Any, Any} || return Any
nextvalstate = Tuple{
typejoin(valT, fieldtype(nextvalstate, 1)),
typejoin(stateT, fieldtype(nextvalstate, 2))}
return (Tuple{valT, stateT} == nextvalstate ? nextvalstate :
fixpoint_iter_type(itrT,
fieldtype(nextvalstate, 1),
fieldtype(nextvalstate, 2)))
end

convert(::Type{Stateful}, itr) = Stateful(itr)

isempty(s::Stateful) = s.nextvalstate === nothing
@inline isempty(s::Stateful) = s.nextvalstate === nothing

function popfirst!(s::Stateful)
isempty(s) && throw(EOFError())
val, state = s.nextvalstate
s.nextvalstate = done(s.itr, state) ? nothing : next(s.itr, state)
s.taken += 1
val
@inline function popfirst!(s::Stateful)
vs = s.nextvalstate
if vs === nothing
throw(EOFError())
else
val, state = vs
if done(s.itr, state)
s.nextvalstate = nothing
else
s.nextvalstate = next(s.itr, state)
end
s.taken += 1
return val
end
end

peek(s::Stateful, sentinel=nothing) = s.nextvalstate !== nothing ? s.nextvalstate[1] : sentinel
start(s::Stateful) = nothing
next(s::Stateful, state) = popfirst!(s), nothing
done(s::Stateful, state) = isempty(s)
@inline peek(s::Stateful, sentinel=nothing) = s.nextvalstate !== nothing ? s.nextvalstate[1] : sentinel
@inline start(s::Stateful) = nothing
@inline next(s::Stateful, state) = popfirst!(s), nothing
@inline done(s::Stateful, state) = isempty(s)
IteratorSize(::Type{Stateful{VS,T}} where VS) where {T} =
isa(IteratorSize(T), SizeUnknown) ? SizeUnknown() : HasLength()
eltype(::Type{Stateful{VS, T}} where VS) where {T} = eltype(T)
Expand Down
6 changes: 3 additions & 3 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -495,15 +495,15 @@ end
end

@testset "Iterators.Stateful" begin
let a = Iterators.Stateful("abcdef")
let a = @inferred(Iterators.Stateful("abcdef"))
@test !isempty(a)
@test popfirst!(a) == 'a'
@test collect(Iterators.take(a, 3)) == ['b','c','d']
@test collect(a) == ['e', 'f']
end
let a = Iterators.Stateful([1, 1, 1, 2, 3, 4])
let a = @inferred(Iterators.Stateful([1, 1, 1, 2, 3, 4]))
for x in a; x == 1 || break; end
@test Base.peek(a) == 3
@test sum(a) == 7
end
end
end

0 comments on commit 9c1b88a

Please sign in to comment.