Skip to content
This repository has been archived by the owner on May 4, 2019. It is now read-only.

Commit

Permalink
Merge pull request #152 from mbauman/mb/10525
Browse files Browse the repository at this point in the history
Internal API changes due to 10525
  • Loading branch information
simonster committed Jun 22, 2015
2 parents b5d39c3 + 0bb50b4 commit e7172be
Showing 1 changed file with 64 additions and 12 deletions.
76 changes: 64 additions & 12 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,68 @@ function Base.to_index(A::DataArray)
end

# Fast implementation of checkbounds for DataArray input
Base.checkbounds(sz::Int, I::AbstractDataVector{Bool}) =
length(I) == sz || throw(BoundsError())
function Base.checkbounds{T<:Real}(sz::Int, I::AbstractDataArray{T})
anyna(I) && throw(NAException("cannot index into an array with a DataArray containing NAs"))
extr = daextract(I)
for i = 1:length(I)
@inbounds v = unsafe_getindex_notna(I, extr, i)
checkbounds(sz, v)
# This overrides an internal API that changed after #10525
if VERSION < v"0.4-dev+5194"
Base.checkbounds(sz::Int, I::AbstractDataVector{Bool}) =
length(I) == sz || throw(BoundsError())
function Base.checkbounds{T<:Real}(sz::Int, I::AbstractDataArray{T})
anyna(I) && throw(NAException("cannot index into an array with a DataArray containing NAs"))
extr = daextract(I)
for i = 1:length(I)
@inbounds v = unsafe_getindex_notna(I, extr, i)
checkbounds(sz, v)
end
end
else
Base._checkbounds(sz::Int, I::AbstractDataVector{Bool}) = length(I) == sz
function Base._checkbounds{T<:Real}(sz::Int, I::AbstractDataArray{T})
anyna(I) && throw(NAException("cannot index into an array with a DataArray containing NAs"))
extr = daextract(I)
b = true
for i = 1:length(I)
@inbounds v = unsafe_getindex_notna(I, extr, i)
b &= Base._checkbounds(sz, v)
end
b
end
end

# Indexing uses undocumented APIs to determine the resulting shape. These APIs
# changed to support indices like `:`, which need the array to know the shape
if VERSION < v"0.4-dev+5194" # Merge commit of Julialang/julia#10525
index_shape(A, I...) = Base.index_shape(I...)
_lengths() = ()
_lengths(i, I...) = tuple(length(i), _lengths(I...)...)
index_lengths(A, I...) = _lengths(I...)
function throw_setindex_mismatch(X, I)
if length(I) == 1
throw(DimensionMismatch("tried to assign $(length(X)) elements to $(I[1]) destinations"))
else
throw(DimensionMismatch("tried to assign $(Base.dims2string(size(X))) array to $(Base.dims2string(I)) destination"))
end
end
setindex_shape_check(X::AbstractArray) =
(length(X)==1 || throw_setindex_mismatch(X,()))
setindex_shape_check(X::AbstractArray, i::Int) =
(length(X)==i || throw_setindex_mismatch(X, (i,)))
setindex_shape_check{T}(X::AbstractArray{T,1}, i::Int) =
(length(X)==i || throw_setindex_mismatch(X, (i,)))
setindex_shape_check{T}(X::AbstractArray{T,1}, i::Int, j::Int) =
(length(X)==i*j || throw_setindex_mismatch(X, (i,j)))
function setindex_shape_check{T}(X::AbstractArray{T,2}, i::Int, j::Int)
if length(X) != i*j
throw_setindex_mismatch(X, (i,j))
end
sx1 = size(X,1)
if !(i == 1 || i == sx1 || sx1 == 1)
throw_setindex_mismatch(X, (i,j))
end
end
setindex_shape_check(X, I::Int...) = nothing # Non-arrays broadcast to all idxs
else
import Base: index_shape, index_lengths, setindex_shape_check
end

# Fallbacks to avoid ambiguity
setindex!(t::AbstractDataArray, x, i::Real) =
throw(MethodError(setindex!, typeof(t), typeof(x), typeof(i)))
Expand Down Expand Up @@ -144,7 +195,7 @@ end
end

function _getindex{T}(A::DataArray{T}, I::@compat Tuple{Vararg{Union(Int,AbstractVector)}})
shape = Base.index_shape(I...)
shape = index_shape(A, I...)
_getindex!(DataArray(Array(T, shape), falses(shape)), A, I...)
end

Expand Down Expand Up @@ -245,21 +296,22 @@ end
end
else
X = x
@ncall N Base.setindex_shape_check X I
idxlens = @ncall N index_lengths A I
@ncall N setindex_shape_check X (d->idxlens[d])
k = 1
if isa(A, PooledDataArray) && isa(X, PooledDataArray)
# When putting one PDA into another, first unify the pools
# and then translate the references
poolmap = combine_pools!(A.pool, X.pool)
Arefs = A.refs
Xrefs = X.refs
@nloops N i d->(1:length(I_d)) d->(@inbounds offset_{d-1} = offset_d + (Base.unsafe_getindex(I_d, i_d)-1)*stride_d) begin
@nloops N i d->(1:idxlens[d]) d->(@inbounds offset_{d-1} = offset_d + (Base.unsafe_getindex(I_d, i_d)-1)*stride_d) begin
@inbounds Arefs[offset_0] = Xrefs[k] == 0 ? 0 : poolmap[Xrefs[k]]
k += 1
end
else
Xextr = daextract(X)
@nloops N i d->(1:length(I_d)) d->(@inbounds offset_{d-1} = offset_d + (Base.unsafe_getindex(I_d, i_d)-1)*stride_d) begin
@nloops N i d->(1:idxlens[d]) d->(@inbounds offset_{d-1} = offset_d + (Base.unsafe_getindex(I_d, i_d)-1)*stride_d) begin
@inbounds if isa(X, AbstractDataArray) && unsafe_isna(X, Xextr, k)
unsafe_setna!(A, Aextr, offset_0)
else
Expand Down

0 comments on commit e7172be

Please sign in to comment.