Skip to content

Commit

Permalink
Add fallback getindex for AbstractArray
Browse files Browse the repository at this point in the history
Enable fancy indexing behaviors by default for AbstractArray, only requiring subtypes to implement either:

    getindex(::T, ::Int) # if linearindexing(T) == LinearFast()
    getindex(::T, ::Int, ::Int, #=...ndims(A) indices...=#) # if LinearSlow()

(and, optionally, similar definitions for unsafe_getindex)
  • Loading branch information
mbauman committed Mar 23, 2015
1 parent fad0637 commit 7958f08
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 96 deletions.
154 changes: 116 additions & 38 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ similar (a::AbstractArray, T) = similar(a, T, size(a))
similar{T}(a::AbstractArray{T}, dims::Dims) = similar(a, T, dims)
similar{T}(a::AbstractArray{T}, dims::Int...) = similar(a, T, dims)
similar (a::AbstractArray, T, dims::Int...) = similar(a, T, dims)
# similar creates an Array by default
similar (a::AbstractArray, T, dims::Dims) = Array(T, dims)

function reshape(a::AbstractArray, dims::Dims)
if prod(dims) != length(a)
Expand Down Expand Up @@ -412,19 +414,6 @@ imag{T<:Real}(x::AbstractArray{T}) = zero(x)

\(A::Number, B::AbstractArray) = B ./ A

## Indexing: getindex ##

getindex(t::AbstractArray, i::Real) = error("indexing not defined for ", typeof(t))

# linear indexing with a single multi-dimensional index
function getindex(A::AbstractArray, I::AbstractArray)
x = similar(A, size(I))
for i=1:length(I)
x[i] = A[I[i]]
end
return x
end

# index A[:,:,...,i,:,:,...] where "i" is in dimension "d"
# TODO: more optimized special cases
slicedim(A::AbstractArray, d::Integer, i) =
Expand Down Expand Up @@ -479,37 +468,126 @@ setindex!(t::AbstractArray, x, i::Real) =
error("setindex! not defined for ",typeof(t))
setindex!(t::AbstractArray, x) = throw(MethodError(setindex!, (t, x)))

## Indexing: handle more indices than dimensions if "extra" indices are 1

# Don't require vector/matrix subclasses to implement more than 1/2 indices,
# respectively, by handling the extra dimensions in AbstractMatrix.

function getindex(A::AbstractVector, i1,i2,i3...)
if i2*prod(i3) != 1
throw(BoundsError())
## Approach:
# We only define one fallback method on getindex for all argument types.
# That dispatches to an (inlined) internal _getindex function, where the goal is
# to transform the indices such that we can call the only getindex method that
# we require AbstractArray subtypes must define, either:
# getindex(::T, ::Int) # if linearindexing(T) == LinearFast()
# getindex(::T, ::Int, ::Int, #=...ndims(A) indices...=#) if LinearSlow()
# Unfortunately, it is currently impossible to express the latter method for
# arbitrary dimensionalities. We could get around that with ::CartesianIndex{N},
# but that isn't as obvious and would require that the function be inlined to
# avoid allocations. If the subtype hasn't defined those methods, it goes back
# to the _getindex function where an error is thrown to prevent stack overflows.
#
# We use the same scheme for unsafe_getindex, with the exception that we can
# fallback to the safe version if the subtype hasn't defined the required
# unsafe method.

macro _inline_expr()
Expr(:meta, :inline)
end

function getindex(A::AbstractArray, I...)
@_inline_expr
_getindex(linearindexing(A), A, I...)
end
function unsafe_getindex(A::AbstractArray, I...)
@_inline_expr
_unsafe_getindex(linearindexing(A), A, I...)
end
## Internal defitions
_getindex(::LinearFast, A::AbstractArray) = (@_inline_expr; getindex(A, 1))
_getindex(::LinearSlow, A::AbstractArray) = (@_inline_expr; _getindex(A, 1))
_unsafe_getindex(::LinearFast, A::AbstractArray) = (@_inline_expr; unsafe_getindex(A, 1))
_unsafe_getindex(::LinearSlow, A::AbstractArray) = (@_inline_expr; _unsafe_getindex(A, 1))
_getindex(::LinearIndexing, A::AbstractArray, I...) = error("indexing $(typeof(A)) with types $(typeof(I)) is not supported")
_unsafe_getindex(::LinearIndexing, A::AbstractArray, I...) = error("indexing $(typeof(A)) with types $(typeof(I)) is not supported")

## LinearFast Scalar indexing
_getindex(::LinearFast, A::AbstractArray, I::Int) = error("indexing not defined for ", typeof(A))
stagedfunction _getindex(::LinearFast, A::AbstractArray, I::Real...)
N = length(I)
Isplat = Expr[:(to_index(I[$d])) for d = 1:N]
quote
$(Expr(:meta, :inline))
# We must check bounds for sub2ind; so we can then call unsafe_getindex
checkbounds(A, I...)
unsafe_getindex(A, sub2ind(size(A), $(Isplat...)))
end
end
_unsafe_getindex(::LinearFast, A::AbstractArray, I::Int) = (@_inline_expr; getindex(A, I))
stagedfunction _unsafe_getindex(::LinearFast, A::AbstractArray, I::Real...)
N = length(I)
Isplat = Expr[:(to_index(I[$d])) for d = 1:N]
quote
$(Expr(:meta, :inline))
unsafe_getindex(A, sub2ind(size(A), $(Isplat...)))
end
end

# LinearSlow Scalar indexing
stagedfunction _getindex{T,AN}(::LinearSlow, A::AbstractArray{T,AN}, I::Real...)
N = length(I)
if N == AN
:(error("indexing not defined for ", typeof(A)))
elseif N > AN
# Drop trailing ones
Isplat = Expr[:(to_index(I[$d])) for d = 1:AN]
Osplat = Expr[:(to_index(I[$d]) == 1) for d = AN+1:N]
quote
$(Expr(:meta, :inline))
(&)($(Osplat...)) || throw(BoundsError(A, I))
getindex(A, $(Isplat...))
end
else
# Expand the last index into the appropriate number of indices
Isplat = Expr[:(to_index(I[$d])) for d = 1:N-1]
i = 0
for d=N:AN
push!(Isplat, :(s[$(i+=1)]))
end
sz = Expr(:tuple)
sz.args = Expr[:(size(A, $d)) for d=N:AN]
quote
$(Expr(:meta, :inline))
s = ind2sub($sz, to_index(I[$N]))
getindex(A, $(Isplat...))
end
end
A[i1]
end
function getindex(A::AbstractMatrix, i1,i2,i3,i4...)
if i3*prod(i4) != 1
throw(BoundsError())
stagedfunction _unsafe_getindex{T,AN}(::LinearSlow, A::AbstractArray{T,AN}, I::Real...)
N = length(I)
if N == AN
Isplat = Expr[:(to_index(I[$d])) for d = 1:N]
:(getindex(A, $(Isplat...)))
elseif N > AN
# Drop trailing dimensions (unchecked)
Isplat = Expr[:(to_index(I[$d])) for d = 1:AN]
quote
$(Expr(:meta, :inline))
unsafe_getindex(A, $(Isplat...))
end
else
# Expand the last index into the appropriate number of indices
Isplat = Expr[:(to_index(I[$d])) for d = 1:N-1]
for d=N:AN
push!(Isplat, :(s[$(d-N+1)]))
end
sz = Expr(:tuple)
sz.args = Expr[:(size(A, $d)) for d=N:AN]
quote
$(Expr(:meta, :inline))
s = ind2sub($sz, to_index(I[$N]))
unsafe_getindex(A, $(Isplat...))
end
end
A[i1,i2]
end

function setindex!(A::AbstractVector, x, i1,i2,i3...)
if i2*prod(i3) != 1
throw(BoundsError())
end
A[i1] = x
end
function setindex!(A::AbstractMatrix, x, i1,i2,i3,i4...)
if i3*prod(i4) != 1
throw(BoundsError())
end
A[i1,i2] = x
end

## Setindex! is defined similarly:
unsafe_setindex!(A::AbstractArray, v, I...) = (@_inline_expr; @inbounds return setindex!(A, v, I...))
## get (getindex with a default value) ##

typealias RangeVecIntList{A<:AbstractVector{Int}} Union((Union(Range, AbstractVector{Int})...), AbstractVector{UnitRange{Int}}, AbstractVector{Range{Int}}, AbstractVector{A})
Expand Down
Loading

0 comments on commit 7958f08

Please sign in to comment.