From 3ea63102d3f7ac253b3c6e85431e7d97ec2699cb Mon Sep 17 00:00:00 2001 From: Rafael Fourquet Date: Fri, 23 Jun 2017 15:52:19 +0200 Subject: [PATCH 1/2] add IteratorAccess trait: for indexing capabilities --- base/generator.jl | 17 ++++++++++++ base/iterators.jl | 71 +++++++++++++++++++++++++++++++++++++++++++++-- base/sort.jl | 31 ++++++++++++--------- 3 files changed, 103 insertions(+), 16 deletions(-) diff --git a/base/generator.jl b/base/generator.jl index d49ed635ff3f1..42ecf57f941e9 100644 --- a/base/generator.jl +++ b/base/generator.jl @@ -108,11 +108,28 @@ Base.HasEltype() iteratoreltype(x) = iteratoreltype(typeof(x)) iteratoreltype(::Type) = HasEltype() # HasEltype is the default +abstract type IteratorAccess end +struct ForwardAccess <: IteratorAccess end +struct RandomAccess <: IteratorAccess end +struct WritableRandomAccess <: IteratorAccess end + +iteratoraccess(x) = iteratoraccess(typeof(x)) +iteratoraccess(::Type) = ForwardAccess() # ForwardAccess is the default + +removewritable(::ForwardAccess) = ForwardAccess() +removewritable(::Union{RandomAccess,WritableRandomAccess}) = RandomAccess() + iteratorsize(::Type{<:AbstractArray}) = HasShape() iteratorsize(::Type{Generator{I,F}}) where {I,F} = iteratorsize(I) + +iteratoraccess(::Type{<:AbstractArray}) = RandomAccess() +iteratoraccess(::Type{<:Array}) = WritableRandomAccess() +iteratoraccess(::Type{Generator{I,F}}) where{I,F} = removewritable(iteratoraccess(I)) + length(g::Generator) = length(g.iter) size(g::Generator) = size(g.iter) indices(g::Generator) = indices(g.iter) ndims(g::Generator) = ndims(g.iter) +getindex(g::Generator, key...) = map(g.f, g.iter[key...]) iteratoreltype(::Type{Generator{I,T}}) where {I,T} = EltypeUnknown() diff --git a/base/iterators.jl b/base/iterators.jl index 49228be91a562..7367b0bc73d94 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -2,9 +2,10 @@ module Iterators -import Base: start, done, next, isempty, length, size, eltype, iteratorsize, iteratoreltype, indices, ndims +import Base: start, done, next, isempty, length, size, eltype, iteratorsize, iteratoreltype, iteratoraccess, indices, ndims, getindex, setindex! -using Base: tuple_type_cons, SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype, OneTo, @propagate_inbounds +using Base: tuple_type_cons, SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype, + ForwardAccess, RandomAccess, WritableRandomAccess, removewritable, OneTo, @propagate_inbounds export enumerate, zip, rest, countfrom, take, drop, cycle, repeated, product, flatten, partition @@ -26,6 +27,12 @@ and_iteratorsize(a, b) = SizeUnknown() and_iteratoreltype(iel::T, ::T) where {T} = iel and_iteratoreltype(a, b) = EltypeUnknown() +and_iteratoraccess(::ForwardAccess, b) = ForwardAccess() +and_iteratoraccess(a, ::ForwardAccess) = ForwardAccess() +and_iteratoraccess(::RandomAccess, b) = RandomAccess() +and_iteratoraccess(a, ::RandomAccess) = RandomAccess() +and_iteratoraccess(::WritableRandomAccess, ::WritableRandomAccess) = WritableRandomAccess() + # enumerate struct Enumerate{I} @@ -67,8 +74,11 @@ end eltype(::Type{Enumerate{I}}) where {I} = Tuple{Int, eltype(I)} +getindex(e::Enumerate, key...) = getindex(zip(1:length(e), e.itr), key...) + iteratorsize(::Type{Enumerate{I}}) where {I} = iteratorsize(I) iteratoreltype(::Type{Enumerate{I}}) where {I} = iteratoreltype(I) +iteratoraccess(::Type{Enumerate{I}}) where {I} = removewritable(iteratoreltype(I)) struct IndexValue{I,A<:AbstractArray} data::A @@ -166,8 +176,12 @@ eltype(::Type{Zip1{I}}) where {I} = Tuple{eltype(I)} end @inline done(z::Zip1, st) = done(z.a,st) +getindex(z::Zip1, key...) = (z.a[key...],) +setindex!(z::Zip1, value::Tuple{I}, key...) where {I} = z.a[key...] = value[1] + iteratorsize(::Type{Zip1{I}}) where {I} = iteratorsize(I) iteratoreltype(::Type{Zip1{I}}) where {I} = iteratoreltype(I) +iteratoraccess(::Type{Zip1{I}}) where {I} = iteratoreltype(I) struct Zip2{I1, I2} <: AbstractZipIterator a::I1 @@ -186,8 +200,13 @@ eltype(::Type{Zip2{I1,I2}}) where {I1,I2} = Tuple{eltype(I1), eltype(I2)} end @inline done(z::Zip2, st) = done(z.a,st[1]) | done(z.b,st[2]) +getindex(z::Zip2, key...) = (z.a[key...], z.b[key...]) +setindex!(z::Zip2, value::Tuple{I1,I2}, key...) where {I1,I2} = (z.a[key...] = value[1]; + z.b[key...] = value[2]) + iteratorsize(::Type{Zip2{I1,I2}}) where {I1,I2} = zip_iteratorsize(iteratorsize(I1),iteratorsize(I2)) iteratoreltype(::Type{Zip2{I1,I2}}) where {I1,I2} = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2)) +iteratoraccess(::Type{Zip2{I1,I2}}) where {I1,I2} = and_iteratoraccess(iteratoraccess(I1),iteratoraccess(I2)) struct Zip{I, Z<:AbstractZipIterator} <: AbstractZipIterator a::I @@ -237,8 +256,17 @@ eltype(::Type{Zip{I,Z}}) where {I,Z} = tuple_type_cons(eltype(I), eltype(Z)) end @inline done(z::Zip, st) = done(z.a,st[1]) | done(z.z,st[2]) +getindex(z::Zip, key...) = tuple(z.a[key...], z.z[key...]...) +setindex!(z::Zip, value::Tuple, key...) = _setindex!(z, value, 1, key) +_setindex!(z::Zip, value, n::Int, key) = (z.a[key...] = value[n]; + _setindex!(z.z, value, n+1, key)) +_setindex!(z::Zip2, value, n::Int, key) = (z.a[key...] = value[n]; + z.b[key...] = value[n+1]) + iteratorsize(::Type{Zip{I1,I2}}) where {I1,I2} = zip_iteratorsize(iteratorsize(I1),iteratorsize(I2)) iteratoreltype(::Type{Zip{I1,I2}}) where {I1,I2} = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2)) +iteratoraccess(::Type{Zip{I1,I2}}) where {I1,I2} = and_iteratoraccess(iteratoraccess(I1),iteratoraccess(I2)) + # filter @@ -343,7 +371,10 @@ start(it::Count) = it.start next(it::Count, state) = (state, state + it.step) done(it::Count, state) = false +getindex(it::Count, key::Integer) = it.start + (key-1)*it.step + iteratorsize(::Type{<:Count}) = IsInfinite() +iteratoraccess(::Type{<:Count}) = RandomAccess() # Take -- iterate through the first n elements @@ -387,6 +418,8 @@ take_iteratorsize(::SizeUnknown) = SizeUnknown() iteratorsize(::Type{Take{I}}) where {I} = take_iteratorsize(iteratorsize(I)) length(t::Take) = _min_length(t.xs, 1:t.n, iteratorsize(t.xs), HasLength()) +iteratoraccess(::Type{Take{I}}) where {I} = iteratoraccess(I) + start(it::Take) = (it.n, start(it.xs)) function next(it::Take, state) @@ -400,6 +433,10 @@ function done(it::Take, state) return n <= 0 || done(it.xs, xs_state) end +getindex(it::Take, key::Integer) = key > it.n ? throw(BoundsError(it, key)) : it.xs[key] +setindex!(it::Take, value, key::Integer) = + key > it.n ? throw(BoundsError(it, key)) : it.xs[key] = value + # Drop -- iterator through all but the first n elements struct Drop{I} @@ -443,6 +480,8 @@ drop_iteratorsize(::IsInfinite) = IsInfinite() iteratorsize(::Type{Drop{I}}) where {I} = drop_iteratorsize(iteratorsize(I)) length(d::Drop) = _diff_length(d.xs, 1:d.n, iteratorsize(d.xs), HasLength()) +iteratoraccess(::Type{Drop{I}}) where {I} = iteratoraccess(I) + function start(it::Drop) xs_state = start(it.xs) for i in 1:it.n @@ -458,6 +497,9 @@ end next(it::Drop, state) = next(it.xs, state) done(it::Drop, state) = done(it.xs, state) +getindex(it::Drop, key::Integer) = it.xs[key+it.n] +setindex!(it::Drop, value, key::Integer) = it.xs[key+it.n] = value + # Cycle an iterator forever struct Cycle{I} @@ -474,6 +516,10 @@ cycle(xs) = Cycle(xs) eltype(::Type{Cycle{I}}) where {I} = eltype(I) iteratoreltype(::Type{Cycle{I}}) where {I} = iteratoreltype(I) iteratorsize(::Type{Cycle{I}}) where {I} = IsInfinite() +iteratoraccess(::Type{Cycle{I}}) where {I} = cycle_iteratoraccess(I, iteratorsize(I)) + +cycle_iteratoraccess(I, _) = iteratoraccess(I) +cycle_iteratoraccess(I, ::SizeUnknown) = ForwardAccess() function start(it::Cycle) s = start(it.xs) @@ -491,6 +537,16 @@ end done(it::Cycle, state) = state[2] +getindex(it::Cycle, key::Integer) = _getindex(it, key, iteratorsize(I)) +_getindex(it::Cycle, key::Integer, ::Union{HasLength,HasShape}) = + it.xs[mod1(key, length(it.xs))] +_getindex(it::Cycle, key::Integer, ::IsInfinite) = it.xs[key] + +setindex!(it::Cycle, value, key::Integer) = _setindex!(it, value, key, iteratorsize(I)) +_setindex!(it::Cycle, value, key::Integer, ::Union{HasLength,HasShape}) = + it.xs[mod1(key, length(it.xs))] = value +_setindex!(it::Cycle, value, key::Integer, ::IsInfinite) = it.xs[key] = value + # Repeated - repeat an object infinitely many times @@ -524,9 +580,11 @@ start(it::Repeated) = nothing next(it::Repeated, state) = (it.x, nothing) done(it::Repeated, state) = false +getindex(it::Repeated, key::Integer) = it.x + iteratorsize(::Type{<:Repeated}) = IsInfinite() iteratoreltype(::Type{<:Repeated}) = HasEltype() - +iteratoraccess(::Type{<:Repeated}) = RandomAccess() # Product -- cartesian product of iterators @@ -578,8 +636,12 @@ indices(p::Prod1) = _prod_indices(p.a, iteratorsize(p.a)) end @inline done(p::Prod1, st) = done(p.a, st) +getindex(p::Prod1, key...) = (p.a[key...],) +setindex!(p::Prod1, value::Tuple{T}, key...) where {T} = p.a[key...] = value[1] + iteratoreltype(::Type{Prod1{I}}) where {I} = iteratoreltype(I) iteratorsize(::Type{Prod1{I}}) where {I} = iteratorsize(I) +iteratoraccess(::Type{Prod1{I}}) where {I} = iteratoraccess(I) # two iterators struct Prod2{I1, I2} <: AbstractProdIterator @@ -607,6 +669,7 @@ eltype(::Type{Prod2{I1,I2}}) where {I1,I2} = Tuple{eltype(I1), eltype(I2)} iteratoreltype(::Type{Prod2{I1,I2}}) where {I1,I2} = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2)) iteratorsize(::Type{Prod2{I1,I2}}) where {I1,I2} = prod_iteratorsize(iteratorsize(I1),iteratorsize(I2)) +iteratoraccess(::Type{Prod2{I1,I2}}) where {I1,I2} = and_iteratoraccess(iteratoraccess(I1), iteratoraccess(I2)) function start(p::AbstractProdIterator) s1, s2 = start(p.a), start(p.b) @@ -633,6 +696,8 @@ end @inline next(p::Prod2, st) = prod_next(p, st) @inline done(p::AbstractProdIterator, st) = st[4] +# TODO: write getindex/setindex! for Prod2 & Prod + # n iterators struct Prod{I1, I2<:AbstractProdIterator} <: AbstractProdIterator a::I1 diff --git a/base/sort.jl b/base/sort.jl index 846723518a9b7..459f066e9db43 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -2,7 +2,7 @@ module Sort -using Base: Order, Checked, copymutable, linearindices, IndexStyle, viewindexing, IndexLinear, _length +using Base: Order, Checked, copymutable, linearindices, IndexStyle, viewindexing, IndexLinear, _length, WritableRandomAccess import Base.sort, @@ -244,7 +244,7 @@ const DEFAULT_STABLE = MergeSort const SMALL_ALGORITHM = InsertionSort const SMALL_THRESHOLD = 20 -function sort!(v::AbstractVector, lo::Int, hi::Int, ::InsertionSortAlg, o::Ordering) +function sort!(v, lo::Int, hi::Int, ::InsertionSortAlg, o::Ordering) @inbounds for i = lo+1:hi j = i x = v[i] @@ -269,7 +269,7 @@ end # Upon return, the pivot is in v[lo], and v[hi] is guaranteed to be # greater than the pivot -@inline function selectpivot!(v::AbstractVector, lo::Int, hi::Int, o::Ordering) +@inline function selectpivot!(v, lo::Int, hi::Int, o::Ordering) @inbounds begin mi = (lo+hi)>>>1 @@ -299,7 +299,7 @@ end # # select a pivot, and partition v according to the pivot -function partition!(v::AbstractVector, lo::Int, hi::Int, o::Ordering) +function partition!(v, lo::Int, hi::Int, o::Ordering) pivot = selectpivot!(v, lo, hi, o) # pivot == v[lo], v[hi] > pivot i, j = lo, hi @@ -318,7 +318,7 @@ function partition!(v::AbstractVector, lo::Int, hi::Int, o::Ordering) return j end -function sort!(v::AbstractVector, lo::Int, hi::Int, a::QuickSortAlg, o::Ordering) +function sort!(v, lo::Int, hi::Int, a::QuickSortAlg, o::Ordering) @inbounds while lo < hi hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o) j = partition!(v, lo, hi, o) @@ -336,7 +336,7 @@ function sort!(v::AbstractVector, lo::Int, hi::Int, a::QuickSortAlg, o::Ordering return v end -function sort!(v::AbstractVector, lo::Int, hi::Int, a::MergeSortAlg, o::Ordering, t=similar(v,0)) +function sort!(v, lo::Int, hi::Int, a::MergeSortAlg, o::Ordering, t=similar(v,0)) @inbounds if lo < hi hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o) @@ -401,7 +401,7 @@ end # end -function sort!(v::AbstractVector, lo::Int, hi::Int, a::PartialQuickSort, +function sort!(v, lo::Int, hi::Int, a::PartialQuickSort, o::Ordering) @inbounds while lo < hi hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o) @@ -427,11 +427,16 @@ end ## generic sorting methods ## -defalg(v::AbstractArray) = DEFAULT_STABLE -defalg(v::AbstractArray{<:Number}) = DEFAULT_UNSTABLE +defalg(v) = defalg(eltype(v)) +defalg(::Type) = DEFAULT_STABLE +defalg(::Type{<:Number}) = DEFAULT_UNSTABLE -function sort!(v::AbstractVector, alg::Algorithm, order::Ordering) - inds = indices(v,1) +sort!(v, alg::Algorithm, order::Ordering) = sort!(v, Base.iteratoraccess(v), alg, order) + +sort!(v, access, alg::Algorithm, order::Ordering) = throw(ArgumentError("the collection must have setindex! defined")) + +function sort!(v, ::WritableRandomAccess, alg::Algorithm, order::Ordering) + inds = linearindices(v) sort!(v,first(inds),last(inds),alg,order) end @@ -473,7 +478,7 @@ julia> v = [(1, "c"), (3, "a"), (2, "b")]; sort!(v, by = x -> x[2]); v (1, "c") ``` """ -function sort!(v::AbstractVector; +function sort!(v; alg::Algorithm=defalg(v), lt=isless, by=identity, @@ -538,7 +543,7 @@ julia> v 2 ``` """ -sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...) +sort(v; kws...) = sort!(copymutable(v); kws...) ## selectperm: the permutation to sort the first k elements of an array ## From cb62824794724ed443c8e7d493d95a60c12ae94d Mon Sep 17 00:00:00 2001 From: Rafael Fourquet Date: Fri, 23 Jun 2017 16:43:58 +0200 Subject: [PATCH 2/2] add doc for iteratoraccess --- base/generator.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/base/generator.jl b/base/generator.jl index 42ecf57f941e9..0a4b23aff8838 100644 --- a/base/generator.jl +++ b/base/generator.jl @@ -113,6 +113,28 @@ struct ForwardAccess <: IteratorAccess end struct RandomAccess <: IteratorAccess end struct WritableRandomAccess <: IteratorAccess end +""" + iteratoraccess(itertype::Type) -> IteratorAccess + +Given the type of an iterator, returns one of the following values: + +* `ForwardAccess()` if the iterator can be iterated over. +* `RandomAccess()` if the iterator supports read-only indexing. +* `WritableRandomAccess()` if the iterator supports read-write indexing. + +The default value (for iterators that do not define this function) is `ForwardAccess()`. + +```jldoctest +julia> Base.iteratoraccess(1:5) +Base.RandomAccess() + +julia> Base.iteratoraccess([1, 2, 3]) +Base.WritableRandomAccess() + +julia> Base.iteratoraccess(drop([1, 2, 3], 1)) +Base.WritableRandomAccess() +``` +""" iteratoraccess(x) = iteratoraccess(typeof(x)) iteratoraccess(::Type) = ForwardAccess() # ForwardAccess is the default