WIP: add IteratorAccess trait: for indexing capabilities #22489

Open: wants to merge 2 commits into base: master
39 changes: 39 additions & 0 deletions base/generator.jl
@@ -108,11 +108,50 @@ Base.HasEltype()
iteratoreltype(x) = iteratoreltype(typeof(x))
iteratoreltype(::Type) = HasEltype() # HasEltype is the default

abstract type IteratorAccess end
struct ForwardAccess <: IteratorAccess end
Member
Will folks want ReverseAccess? In other words, is this intended to collect all supported/efficient ways of iterating over an object? If so we may need to think about traits for iterator arithmetic, etc.

Member Author
The idea was indeed to also support ReverseAccess (although I can't read my mind from almost 3 years ago!), and this was definitely inspired by the C++ traits. As noted in my comment, ReverseAccess is probably "orthogonal" to the three access patterns currently included here.

struct RandomAccess <: IteratorAccess end
struct WritableRandomAccess <: IteratorAccess end

"""
iteratoraccess(itertype::Type) -> IteratorAccess

Given the type of an iterator, returns one of the following values:

* `ForwardAccess()` if the iterator supports only sequential iteration, not indexing.
* `RandomAccess()` if the iterator supports read-only indexing.
* `WritableRandomAccess()` if the iterator supports read-write indexing.

The default value (for iterators that do not define this function) is `ForwardAccess()`.

```jldoctest
julia> Base.iteratoraccess(1:5)
Base.RandomAccess()

julia> Base.iteratoraccess([1, 2, 3])
Base.WritableRandomAccess()

julia> Base.iteratoraccess(Iterators.drop([1, 2, 3], 1))
Base.WritableRandomAccess()
```
"""
iteratoraccess(x) = iteratoraccess(typeof(x))
Contributor @jw3126, Jul 1, 2017
You could spell iteratoraccess as IteratorAccess instead. This would be consistent with some other recent traits. Personally I think it is a good idea to just use the abstract type instead of inventing a new function. Not sure if there is an official recommendation for one or the other though.

Contributor
This follows iteratorsize etc. It could be changed for all iterator traits at once in a different PR.

Contributor
Iterator traits are also spelled in camel case by now; see #25402.
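For readers following this thread, here is a minimal sketch of what the camel-case spelling could look like: the trait function is simply a set of methods on the abstract type itself, in the same way `Base.IteratorSize` works. The definitions below are a standalone toy, not the code in this PR, and the choice of which types map to which trait is copied from the diff above.

```julia
# Standalone toy definitions; nothing here extends Base.
abstract type IteratorAccess end
struct ForwardAccess <: IteratorAccess end
struct RandomAccess <: IteratorAccess end
struct WritableRandomAccess <: IteratorAccess end

# Calling the abstract type plays the role of the lowercase trait function.
IteratorAccess(x) = IteratorAccess(typeof(x))       # instances forward to their type
IteratorAccess(::Type) = ForwardAccess()            # default for unknown iterators
IteratorAccess(::Type{<:AbstractArray}) = RandomAccess()
IteratorAccess(::Type{<:Array}) = WritableRandomAccess()

IteratorAccess([1, 2, 3])   # WritableRandomAccess()
```

With this spelling no separate `iteratoraccess` name is needed, which is the point raised in the comment above.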

iteratoraccess(::Type) = ForwardAccess() # ForwardAccess is the default

removewritable(::ForwardAccess) = ForwardAccess()
removewritable(::Union{RandomAccess,WritableRandomAccess}) = RandomAccess()

iteratorsize(::Type{<:AbstractArray}) = HasShape()
iteratorsize(::Type{Generator{I,F}}) where {I,F} = iteratorsize(I)

iteratoraccess(::Type{<:AbstractArray}) = RandomAccess()
iteratoraccess(::Type{<:Array}) = WritableRandomAccess()
iteratoraccess(::Type{Generator{I,F}}) where {I,F} = removewritable(iteratoraccess(I))

length(g::Generator) = length(g.iter)
size(g::Generator) = size(g.iter)
indices(g::Generator) = indices(g.iter)
ndims(g::Generator) = ndims(g.iter)
getindex(g::Generator, key...) = map(g.f, g.iter[key...])
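To illustrate why a generator is read-only even over a writable parent, here is a small sketch with a hypothetical `MyGenerator` type standing in for `Base.Generator`. It handles only a single integer index, which is a simplification of the `key...` method above.

```julia
# Hypothetical stand-in for Base.Generator, kept separate so Base is untouched.
struct MyGenerator{F,I}
    f::F
    iter::I
end

# Reading applies f lazily to the element found at position i in the parent.
Base.getindex(g::MyGenerator, i::Integer) = g.f(g.iter[i])

g = MyGenerator(x -> x^2, [1, 2, 3])
g[3]   # 9
```

There is no sensible `setindex!` here, because `f` is not invertible in general; that is the reason the trait is passed through `removewritable`.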

iteratoreltype(::Type{Generator{I,T}}) where {I,T} = EltypeUnknown()
71 changes: 68 additions & 3 deletions base/iterators.jl
@@ -2,9 +2,10 @@

module Iterators

import Base: start, done, next, isempty, length, size, eltype, iteratorsize, iteratoreltype, indices, ndims
import Base: start, done, next, isempty, length, size, eltype, iteratorsize, iteratoreltype, iteratoraccess, indices, ndims, getindex, setindex!

using Base: tuple_type_cons, SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype, OneTo, @propagate_inbounds
using Base: tuple_type_cons, SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype,
ForwardAccess, RandomAccess, WritableRandomAccess, removewritable, OneTo, @propagate_inbounds

export enumerate, zip, rest, countfrom, take, drop, cycle, repeated, product, flatten, partition

@@ -26,6 +27,12 @@ and_iteratorsize(a, b) = SizeUnknown()
and_iteratoreltype(iel::T, ::T) where {T} = iel
and_iteratoreltype(a, b) = EltypeUnknown()

and_iteratoraccess(::ForwardAccess, b) = ForwardAccess()
Member
Suggested change
and_iteratoraccess(::ForwardAccess, b) = ForwardAccess()
and_iteratoraccess(::ForwardAccess, b::IteratorAccess) = ForwardAccess()

and similarly elsewhere.

and_iteratoraccess(a, ::ForwardAccess) = ForwardAccess()
and_iteratoraccess(::RandomAccess, b) = RandomAccess()
and_iteratoraccess(a, ::RandomAccess) = RandomAccess()
and_iteratoraccess(::WritableRandomAccess, ::WritableRandomAccess) = WritableRandomAccess()
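As written, the untyped methods above look ambiguous for some argument pairs, for example `(ForwardAccess(), ForwardAccess())` or `(RandomAccess(), ForwardAccess())`. One possible way to avoid this, sketched below under the assumption that the three traits form a strict ordering, is to rank them and return the weaker of the two. The `accessrank` helper is hypothetical and not part of the PR, and its effect on type inference has not been checked.

```julia
# Standalone toy definitions mirroring the trait types in this diff.
abstract type IteratorAccess end
struct ForwardAccess <: IteratorAccess end
struct RandomAccess <: IteratorAccess end
struct WritableRandomAccess <: IteratorAccess end

# Hypothetical helper: order the traits from weakest to strongest.
accessrank(::ForwardAccess) = 1
accessrank(::RandomAccess) = 2
accessrank(::WritableRandomAccess) = 3

# A combined iterator is only as capable as its weakest component.
and_iteratoraccess(a::IteratorAccess, b::IteratorAccess) =
    accessrank(a) <= accessrank(b) ? a : b

and_iteratoraccess(WritableRandomAccess(), RandomAccess())   # RandomAccess()
```

A single method with typed arguments also follows the suggested change above, since nothing other than an `IteratorAccess` can be passed in.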

# enumerate

struct Enumerate{I}
@@ -67,8 +74,11 @@ end

eltype(::Type{Enumerate{I}}) where {I} = Tuple{Int, eltype(I)}

getindex(e::Enumerate, key...) = getindex(zip(1:length(e), e.itr), key...)
Member
Is the codegen here OK? And why the splat? This is a one-dimensional container.
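One way the splat and the intermediate `zip` could be avoided is sketched below, using a hypothetical `MyEnumerate` type in place of `Base.Iterators.Enumerate`. It assumes the wrapped collection uses 1-based linear indexing, so that the counter and the index coincide.

```julia
# Hypothetical stand-in for Enumerate, so the sketch stays self-contained.
struct MyEnumerate{I}
    itr::I
end

# A single integer index suffices: position i pairs with the i-th element.
Base.getindex(e::MyEnumerate, i::Integer) = (i, e.itr[i])

e = MyEnumerate(["a", "b", "c"])
e[2]   # (2, "b")
```

For collections with offset indices the counter produced by iteration and the index would differ, so this shortcut would not apply there.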


iteratorsize(::Type{Enumerate{I}}) where {I} = iteratorsize(I)
iteratoreltype(::Type{Enumerate{I}}) where {I} = iteratoreltype(I)
iteratoraccess(::Type{Enumerate{I}}) where {I} = removewritable(iteratoraccess(I))

struct IndexValue{I,A<:AbstractArray}
data::A
@@ -166,8 +176,12 @@ eltype(::Type{Zip1{I}}) where {I} = Tuple{eltype(I)}
end
@inline done(z::Zip1, st) = done(z.a,st)

getindex(z::Zip1, key...) = (z.a[key...],)
setindex!(z::Zip1, value::Tuple{I}, key...) where {I} = z.a[key...] = value[1]

iteratorsize(::Type{Zip1{I}}) where {I} = iteratorsize(I)
iteratoreltype(::Type{Zip1{I}}) where {I} = iteratoreltype(I)
iteratoraccess(::Type{Zip1{I}}) where {I} = iteratoreltype(I)
Member
Suggested change
iteratoraccess(::Type{Zip1{I}}) where {I} = iteratoreltype(I)
iteratoraccess(::Type{Zip1{I}}) where {I} = IteratorAccess(I)


struct Zip2{I1, I2} <: AbstractZipIterator
a::I1
@@ -186,8 +200,13 @@ eltype(::Type{Zip2{I1,I2}}) where {I1,I2} = Tuple{eltype(I1), eltype(I2)}
end
@inline done(z::Zip2, st) = done(z.a,st[1]) | done(z.b,st[2])

getindex(z::Zip2, key...) = (z.a[key...], z.b[key...])
setindex!(z::Zip2, value::Tuple{I1,I2}, key...) where {I1,I2} = (z.a[key...] = value[1];
z.b[key...] = value[2])

iteratorsize(::Type{Zip2{I1,I2}}) where {I1,I2} = zip_iteratorsize(iteratorsize(I1),iteratorsize(I2))
iteratoreltype(::Type{Zip2{I1,I2}}) where {I1,I2} = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2))
iteratoraccess(::Type{Zip2{I1,I2}}) where {I1,I2} = and_iteratoraccess(iteratoraccess(I1),iteratoraccess(I2))

struct Zip{I, Z<:AbstractZipIterator} <: AbstractZipIterator
a::I
@@ -237,8 +256,17 @@ eltype(::Type{Zip{I,Z}}) where {I,Z} = tuple_type_cons(eltype(I), eltype(Z))
end
@inline done(z::Zip, st) = done(z.a,st[1]) | done(z.z,st[2])

getindex(z::Zip, key...) = tuple(z.a[key...], z.z[key...]...)
setindex!(z::Zip, value::Tuple, key...) = _setindex!(z, value, 1, key)
_setindex!(z::Zip, value, n::Int, key) = (z.a[key...] = value[n];
_setindex!(z.z, value, n+1, key))
_setindex!(z::Zip2, value, n::Int, key) = (z.a[key...] = value[n];
z.b[key...] = value[n+1])

iteratorsize(::Type{Zip{I1,I2}}) where {I1,I2} = zip_iteratorsize(iteratorsize(I1),iteratorsize(I2))
iteratoreltype(::Type{Zip{I1,I2}}) where {I1,I2} = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2))
iteratoraccess(::Type{Zip{I1,I2}}) where {I1,I2} = and_iteratoraccess(iteratoraccess(I1),iteratoraccess(I2))
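The intended behaviour of indexing into a zip can be sketched with a hypothetical two-argument `MyZip2` type: reading returns a tuple of the components at the same position, and writing a tuple scatters its entries back into the parents. Only a single integer index is handled here, unlike the `key...` methods in the diff.

```julia
# Hypothetical stand-in for Zip2, kept outside Base.Iterators.
struct MyZip2{A,B}
    a::A
    b::B
end

# Reading gathers one element from each parent.
Base.getindex(z::MyZip2, i::Integer) = (z.a[i], z.b[i])

# Writing scatters the tuple back, so both parents must be writable.
function Base.setindex!(z::MyZip2, value::Tuple{Any,Any}, i::Integer)
    z.a[i] = value[1]
    z.b[i] = value[2]
    return z
end

xs, ys = [1, 2, 3], ["a", "b", "c"]
z = MyZip2(xs, ys)
z[2] = (20, "B")   # updates xs[2] and ys[2]
```

This is why the zip trait is the combination of its components' traits: a write is only possible when every parent is `WritableRandomAccess`.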


# filter

@@ -343,7 +371,10 @@ start(it::Count) = it.start
next(it::Count, state) = (state, state + it.step)
done(it::Count, state) = false

getindex(it::Count, key::Integer) = it.start + (key-1)*it.step

iteratorsize(::Type{<:Count}) = IsInfinite()
iteratoraccess(::Type{<:Count}) = RandomAccess()
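The closed-form index for a counter can be checked against plain iteration. The sketch below uses a hypothetical `MyCount` type with the same two fields as `Count`, and compares it with `Iterators.countfrom`.

```julia
# Hypothetical stand-in for Base.Iterators.Count.
struct MyCount{S<:Number}
    start::S
    step::S
end

# The k-th element of an arithmetic progression, with no iteration needed.
Base.getindex(it::MyCount, key::Integer) = it.start + (key - 1) * it.step

c = MyCount(5, 3)   # 5, 8, 11, 14, ...
c[4]                # 14
```

Because the sequence is infinite and computed on the fly, only reads make sense, which matches the `RandomAccess()` trait above.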

# Take -- iterate through the first n elements

@@ -387,6 +418,8 @@ take_iteratorsize(::SizeUnknown) = SizeUnknown()
iteratorsize(::Type{Take{I}}) where {I} = take_iteratorsize(iteratorsize(I))
length(t::Take) = _min_length(t.xs, 1:t.n, iteratorsize(t.xs), HasLength())

iteratoraccess(::Type{Take{I}}) where {I} = iteratoraccess(I)

start(it::Take) = (it.n, start(it.xs))

function next(it::Take, state)
@@ -400,6 +433,10 @@ function done(it::Take, state)
return n <= 0 || done(it.xs, xs_state)
end

getindex(it::Take, key::Integer) = key > it.n ? throw(BoundsError(it, key)) : it.xs[key]
setindex!(it::Take, value, key::Integer) =
key > it.n ? throw(BoundsError(it, key)) : it.xs[key] = value

# Drop -- iterate through all but the first n elements

struct Drop{I}
@@ -443,6 +480,8 @@ drop_iteratorsize(::IsInfinite) = IsInfinite()
iteratorsize(::Type{Drop{I}}) where {I} = drop_iteratorsize(iteratorsize(I))
length(d::Drop) = _diff_length(d.xs, 1:d.n, iteratorsize(d.xs), HasLength())

iteratoraccess(::Type{Drop{I}}) where {I} = iteratoraccess(I)

function start(it::Drop)
xs_state = start(it.xs)
for i in 1:it.n
@@ -458,6 +497,9 @@ end
next(it::Drop, state) = next(it.xs, state)
done(it::Drop, state) = done(it.xs, state)

getindex(it::Drop, key::Integer) = it.xs[key+it.n]
setindex!(it::Drop, value, key::Integer) = it.xs[key+it.n] = value
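The index arithmetic for `Take` and `Drop` can be sketched with two hypothetical wrapper types. As in the diff, `MyTake` only rejects indices beyond `n` and leaves every other bounds check to the parent collection.

```julia
# Hypothetical stand-ins for Take and Drop, kept outside Base.Iterators.
struct MyTake{I}
    xs::I
    n::Int
end
struct MyDrop{I}
    xs::I
    n::Int
end

# Take exposes the first n positions of the parent and rejects anything beyond them.
Base.getindex(it::MyTake, key::Integer) =
    key > it.n ? throw(BoundsError(it, key)) : it.xs[key]

# Drop shifts every index past the n skipped elements.
Base.getindex(it::MyDrop, key::Integer) = it.xs[key + it.n]
Base.setindex!(it::MyDrop, value, key::Integer) = (it.xs[key + it.n] = value)

v = [10, 20, 30, 40]
MyDrop(v, 1)[1] = 99   # writes through to v[2]
```

Both wrappers keep the parent's trait unchanged because reads and writes are forwarded one-to-one.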

# Cycle an iterator forever

struct Cycle{I}
@@ -474,6 +516,10 @@ cycle(xs) = Cycle(xs)
eltype(::Type{Cycle{I}}) where {I} = eltype(I)
iteratoreltype(::Type{Cycle{I}}) where {I} = iteratoreltype(I)
iteratorsize(::Type{Cycle{I}}) where {I} = IsInfinite()
iteratoraccess(::Type{Cycle{I}}) where {I} = cycle_iteratoraccess(I, iteratorsize(I))

cycle_iteratoraccess(I, _) = iteratoraccess(I)
cycle_iteratoraccess(I, ::SizeUnknown) = ForwardAccess()

function start(it::Cycle)
s = start(it.xs)
@@ -491,6 +537,16 @@ end

done(it::Cycle, state) = state[2]

getindex(it::Cycle, key::Integer) = _getindex(it, key, iteratorsize(it.xs))
_getindex(it::Cycle, key::Integer, ::Union{HasLength,HasShape}) =
it.xs[mod1(key, length(it.xs))]
_getindex(it::Cycle, key::Integer, ::IsInfinite) = it.xs[key]

setindex!(it::Cycle, value, key::Integer) = _setindex!(it, value, key, iteratorsize(it.xs))
_setindex!(it::Cycle, value, key::Integer, ::Union{HasLength,HasShape}) =
it.xs[mod1(key, length(it.xs))] = value
_setindex!(it::Cycle, value, key::Integer, ::IsInfinite) = it.xs[key] = value
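The wrap-around for a finite parent relies on `mod1`, which maps any positive index into `1:length`. A minimal sketch with a hypothetical `MyCycle` type, covering only the finite-length case:

```julia
# Hypothetical stand-in for Cycle over a parent with a known length.
struct MyCycle{I}
    xs::I
end

# mod1 wraps 1-based indices: with length 3, index 4 maps to 1 and index 6 maps to 3.
Base.getindex(it::MyCycle, key::Integer) = it.xs[mod1(key, length(it.xs))]

c = MyCycle([:a, :b, :c])
(c[1], c[3], c[4], c[7])   # (:a, :c, :a, :a)
```

The indexed values agree with what iterating `Iterators.cycle` over the same vector produces at the same positions.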


# Repeated - repeat an object infinitely many times

@@ -524,9 +580,11 @@ start(it::Repeated) = nothing
next(it::Repeated, state) = (it.x, nothing)
done(it::Repeated, state) = false

getindex(it::Repeated, key::Integer) = it.x

iteratorsize(::Type{<:Repeated}) = IsInfinite()
iteratoreltype(::Type{<:Repeated}) = HasEltype()

iteratoraccess(::Type{<:Repeated}) = RandomAccess()

# Product -- cartesian product of iterators

@@ -578,8 +636,12 @@ indices(p::Prod1) = _prod_indices(p.a, iteratorsize(p.a))
end
@inline done(p::Prod1, st) = done(p.a, st)

getindex(p::Prod1, key...) = (p.a[key...],)
setindex!(p::Prod1, value::Tuple{T}, key...) where {T} = p.a[key...] = value[1]

iteratoreltype(::Type{Prod1{I}}) where {I} = iteratoreltype(I)
iteratorsize(::Type{Prod1{I}}) where {I} = iteratorsize(I)
iteratoraccess(::Type{Prod1{I}}) where {I} = iteratoraccess(I)

# two iterators
struct Prod2{I1, I2} <: AbstractProdIterator
@@ -607,6 +669,7 @@ eltype(::Type{Prod2{I1,I2}}) where {I1,I2} = Tuple{eltype(I1), eltype(I2)}

iteratoreltype(::Type{Prod2{I1,I2}}) where {I1,I2} = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2))
iteratorsize(::Type{Prod2{I1,I2}}) where {I1,I2} = prod_iteratorsize(iteratorsize(I1),iteratorsize(I2))
iteratoraccess(::Type{Prod2{I1,I2}}) where {I1,I2} = and_iteratoraccess(iteratoraccess(I1), iteratoraccess(I2))

function start(p::AbstractProdIterator)
s1, s2 = start(p.a), start(p.b)
@@ -633,6 +696,8 @@ end
@inline next(p::Prod2, st) = prod_next(p, st)
@inline done(p::AbstractProdIterator, st) = st[4]

# TODO: write getindex/setindex! for Prod2 & Prod
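One possible shape for the missing `getindex`, offered only as a sketch and not as what the PR author intends: a two-factor product indexed by one integer per factor. `MyProd2` is a hypothetical type, and both factors are assumed to be one-dimensional.

```julia
# Hypothetical stand-in for Prod2 with two one-dimensional factors.
struct MyProd2{A,B}
    a::A
    b::B
end

Base.size(p::MyProd2) = (length(p.a), length(p.b))

# Element (i, j) of the product pairs the i-th item of a with the j-th item of b.
Base.getindex(p::MyProd2, i::Integer, j::Integer) = (p.a[i], p.b[j])

p = MyProd2([1, 2, 3], ["x", "y"])
p[3, 2]   # (3, "y")
```

A `setindex!` is less obvious for products: writing to position `(i, j)` would change `p.a[i]`, and with it every other element in row `i`, so it may be safer to treat products as read-only.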

# n iterators
struct Prod{I1, I2<:AbstractProdIterator} <: AbstractProdIterator
a::I1
31 changes: 18 additions & 13 deletions base/sort.jl
@@ -2,7 +2,7 @@

module Sort

using Base: Order, Checked, copymutable, linearindices, IndexStyle, viewindexing, IndexLinear, _length
using Base: Order, Checked, copymutable, linearindices, IndexStyle, viewindexing, IndexLinear, _length, WritableRandomAccess

import
Base.sort,
@@ -244,7 +244,7 @@ const DEFAULT_STABLE = MergeSort
const SMALL_ALGORITHM = InsertionSort
const SMALL_THRESHOLD = 20

function sort!(v::AbstractVector, lo::Int, hi::Int, ::InsertionSortAlg, o::Ordering)
function sort!(v, lo::Int, hi::Int, ::InsertionSortAlg, o::Ordering)
@inbounds for i = lo+1:hi
j = i
x = v[i]
@@ -269,7 +269,7 @@ end
# Upon return, the pivot is in v[lo], and v[hi] is guaranteed to be
# greater than the pivot

@inline function selectpivot!(v::AbstractVector, lo::Int, hi::Int, o::Ordering)
@inline function selectpivot!(v, lo::Int, hi::Int, o::Ordering)
@inbounds begin
mi = (lo+hi)>>>1

@@ -299,7 +299,7 @@ end
#
# select a pivot, and partition v according to the pivot

function partition!(v::AbstractVector, lo::Int, hi::Int, o::Ordering)
function partition!(v, lo::Int, hi::Int, o::Ordering)
pivot = selectpivot!(v, lo, hi, o)
# pivot == v[lo], v[hi] > pivot
i, j = lo, hi
@@ -318,7 +318,7 @@ function partition!(v::AbstractVector, lo::Int, hi::Int, o::Ordering)
return j
end

function sort!(v::AbstractVector, lo::Int, hi::Int, a::QuickSortAlg, o::Ordering)
function sort!(v, lo::Int, hi::Int, a::QuickSortAlg, o::Ordering)
@inbounds while lo < hi
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)
j = partition!(v, lo, hi, o)
@@ -336,7 +336,7 @@ function sort!(v::AbstractVector, lo::Int, hi::Int, a::QuickSortAlg, o::Ordering
return v
end

function sort!(v::AbstractVector, lo::Int, hi::Int, a::MergeSortAlg, o::Ordering, t=similar(v,0))
function sort!(v, lo::Int, hi::Int, a::MergeSortAlg, o::Ordering, t=similar(v,0))
@inbounds if lo < hi
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)

@@ -401,7 +401,7 @@ end
# end


function sort!(v::AbstractVector, lo::Int, hi::Int, a::PartialQuickSort,
function sort!(v, lo::Int, hi::Int, a::PartialQuickSort,
o::Ordering)
@inbounds while lo < hi
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)
@@ -427,11 +427,16 @@ end

## generic sorting methods ##

defalg(v::AbstractArray) = DEFAULT_STABLE
defalg(v::AbstractArray{<:Number}) = DEFAULT_UNSTABLE
defalg(v) = defalg(eltype(v))
defalg(::Type) = DEFAULT_STABLE
defalg(::Type{<:Number}) = DEFAULT_UNSTABLE

function sort!(v::AbstractVector, alg::Algorithm, order::Ordering)
inds = indices(v,1)
sort!(v, alg::Algorithm, order::Ordering) = sort!(v, Base.iteratoraccess(v), alg, order)

sort!(v, access, alg::Algorithm, order::Ordering) = throw(ArgumentError("the collection must have setindex! defined"))

function sort!(v, ::WritableRandomAccess, alg::Algorithm, order::Ordering)
inds = linearindices(v)
sort!(v,first(inds),last(inds),alg,order)
end
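The dispatch pattern used here can be tried in isolation. The sketch below re-creates the trait outside Base and uses a hypothetical `mysort!` entry point with a plain insertion sort standing in for the real algorithms, so that `Base.sort!` is left untouched.

```julia
# Standalone toy trait definitions mirroring this diff.
abstract type IteratorAccess end
struct ForwardAccess <: IteratorAccess end
struct RandomAccess <: IteratorAccess end
struct WritableRandomAccess <: IteratorAccess end

iteratoraccess(x) = iteratoraccess(typeof(x))
iteratoraccess(::Type) = ForwardAccess()
iteratoraccess(::Type{<:AbstractArray}) = RandomAccess()
iteratoraccess(::Type{<:Array}) = WritableRandomAccess()

# Hypothetical entry point: look up the trait, then dispatch on it.
mysort!(v) = mysort!(v, iteratoraccess(v))

# Fallback for anything that cannot be written to in place.
mysort!(v, ::IteratorAccess) =
    throw(ArgumentError("the collection must have setindex! defined"))

function mysort!(v, ::WritableRandomAccess)
    # Insertion sort over the linear indices, as a placeholder algorithm.
    inds = eachindex(v)
    for i in first(inds)+1:last(inds)
        x = v[i]
        j = i
        while j > first(inds) && isless(x, v[j-1])
            v[j] = v[j-1]
            j -= 1
        end
        v[j] = x
    end
    return v
end

mysort!([3, 1, 2])   # [1, 2, 3]
```

Calling `mysort!(1:3)` reaches the fallback and raises the `ArgumentError`, because a range is `RandomAccess()` but not writable.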

@@ -473,7 +478,7 @@ julia> v = [(1, "c"), (3, "a"), (2, "b")]; sort!(v, by = x -> x[2]); v
(1, "c")
```
"""
function sort!(v::AbstractVector;
function sort!(v;
alg::Algorithm=defalg(v),
lt=isless,
by=identity,
@@ -538,7 +543,7 @@ julia> v
2
```
"""
sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...)
sort(v; kws...) = sort!(copymutable(v); kws...)

## selectperm: the permutation to sort the first k elements of an array ##
