Skip to content

Commit

Permalink
revised getindex(::AbstractArray,::TupleVector)
Browse files Browse the repository at this point in the history
  • Loading branch information
chakravala committed Jul 29, 2023
1 parent 8730268 commit 75cf4ef
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/abstractvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ similar_type(::Type{SA}) where {SA<:TupleVector} = similar_type(SA,eltype(SA))
similar_type(::SA,::Type{T}) where {SA<:TupleVector{N},T} where N = similar_type(SA,T,Val(N))
similar_type(::Type{SA},::Type{T}) where {SA<:TupleVector{N},T} where N = similar_type(SA,T,Val(N))

similar_type(::A,n::Val) where {A<:AbstractVector} = similar_type(A,eltype(A),n)
similar_type(::Type{A},n::Val) where {A<:AbstractVector} = similar_type(A,eltype(A),n)
similar_type(::A,n::Val) where {A<:AbstractArray} = similar_type(A,eltype(A),n)
similar_type(::Type{A},n::Val) where {A<:AbstractArray} = similar_type(A,eltype(A),n)

similar_type(::A,::Type{T},n::Val) where {A<:AbstractVector,T} = similar_type(A,T,n)
similar_type(::A,::Type{T},n::Val) where {A<:AbstractArray,T} = similar_type(A,T,n)

# We should be able to deal with SOneTo axes
@pure similar_type(s::SOneTo) = similar_type(typeof(s))
@pure similar_type(::Type{SOneTo{n}}) where n = similar_type(SOneTo{n}, Int, Val(n))

# Default types
# Generally, use TupleVector
similar_type(::Type{A},::Type{T},n::Val) where {A<:AbstractVector,T} = default_similar_type(T,n)
similar_type(::Type{A},::Type{T},n::Val) where {A<:AbstractArray,T} = default_similar_type(T,n)
default_similar_type(::Type{T},::Val{N}) where {N,T} = Values{N,T}

similar_type(::Type{SA},::Type{T},n::Val) where {SA<:Variables,T} = mutable_similar_type(T,n)
Expand Down
145 changes: 145 additions & 0 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,38 @@ end
@inline index_size(s::Val, a::SOneTo{n}) where n = Val(n)

@inline index_sizes(::S, ind) where {S<:Val} = index_size(S, ind)
@inline index_sizes(::S, inds...) where {S<:Val} = map(index_size, S, inds)

@inline index_sizes(::Int) = Val(1)
@inline index_sizes(a::TupleVector{N}) where N = Val(N)

@inline index_sizes2() = ()
@inline index_sizes2(::Int, inds...) = (Val(()), index_sizes2(inds...)...)
@inline index_sizes2(a::TupleVector{N}, inds...) where N = (Val(N), index_sizes2(inds...)...)

out_index_size(ind_size::Type{Val{N}}) where N = Val(N)
linear_index_size(ind_size::Type{Val{N}}) where N = N

out_index_size2(ind_sizes::Type{<:Val}...) = Val(_out_index_size2((), ind_sizes...))
@inline _out_index_size2(t::Tuple) = t
@inline _out_index_size2(t::Tuple, ::Type{Val{S}}, ind_sizes...) where {S} = _out_index_size2((t..., S...), ind_sizes...)

linear_index_size2(ind_sizes::Type{<:Val}...) = _linear_index_size2((), ind_sizes...)
@inline _linear_index_size2(t::Tuple) = t
@inline _linear_index_size2(t::Tuple, ::Type{Val{S}}, ind_sizes...) where {S} = _linear_index_size2((t..., prod(S)), ind_sizes...)

untuple(::Val{t}) where t = Val(t[1])

_ind(::Int, ::Type{Int}) = :ind
_ind(i::Int, ::Type{<:TupleVector}) = :(ind[$i])
_ind(j::Int, ::Type{Colon}) = j
_ind(j::Int, ::Type{<:SOneTo}) = j

_ind2(i::Int, ::Int, ::Type{Int}) = :(inds[$i])
_ind2(i::Int, j::Int, ::Type{<:TupleVector}) = :(inds[$i][$j])
_ind2(i::Int, j::Int, ::Type{Colon}) = j
_ind2(i::Int, j::Int, ::Type{<:SOneTo}) = j

################################
## Non-scalar linear indexing ##
################################
Expand Down Expand Up @@ -178,6 +198,11 @@ function Base._getindex(::IndexStyle, A::AbstractVector, ind::TupleIndexing)
return StaticVectors._getindex(A, index_sizes(unwrap(ind)), unwrap(ind))
end

function Base._getindex(::IndexStyle, A::AbstractArray, i1::TupleIndexing, I::TupleIndexing...)
inds = (unwrap(i1), map(unwrap, I)...)
return StaticVectors._getindex(A, index_sizes2(inds...), inds)
end

@generated function _getindex(a::AbstractVector, ind_size::Val, ind)
newsize = out_index_size(ind_size)
linearsize = linear_index_size(ind_size)
Expand All @@ -188,6 +213,42 @@ end
end
end

@generated function _getindex(a::AbstractArray, ind_sizes::Tuple{Vararg{Val}}, inds)
newsize = untuple(out_index_size2(ind_sizes.parameters...))
linearsizes = linear_index_size2(ind_sizes.parameters...)
exprs = Array{Expr}(undef, linearsizes)

# Iterate over input indices
ind_types = inds.parameters
current_ind = ones(Int,length(linearsizes))
more = !isempty(exprs)
while more
exprs_tmp = [_ind2(i, current_ind[i], ind_types[i]) for i = 1:length(linearsizes)]
exprs[current_ind...] = :(getindex(a, $(exprs_tmp...)))

# increment current_ind
current_ind[1] += 1
for i 1:length(linearsizes)
if current_ind[i] > linearsizes[i]
if i == length(linearsizes)
more = false
break
else
current_ind[i] = 1
current_ind[i+1] += 1
end
else
break
end
end
end

quote
Base.@_propagate_inbounds_meta
similar_type(a, $newsize)(tuple($(exprs...)))
end
end

# setindex!

@propagate_inbounds function Base.setindex!(a::TupleVector{N}, value, ind::Union{Int, TupleVector{M, Int} where M, Colon}) where N
Expand All @@ -198,6 +259,11 @@ function Base._setindex!(::IndexStyle, a::AbstractVector, value, ind::TupleIndex
return StaticVectors._setindex!(a, value, index_sizes(ind), unwrap(ind))
end

function Base._setindex!(::IndexStyle, a::AbstractArray, value, i1::TupleIndexing, I::TupleIndexing...)
inds = (unwrap(i1), map(unwrap, I)...)
return StaticVectors._setindex!(a, value, index_sizes2(inds...), inds)
end

# setindex! from a scalar
@generated function _setindex!(a::AbstractVector, value, ind_sizes::Val, ind)
linearsize = linear_index_size(ind_size)
Expand All @@ -209,6 +275,42 @@ end
end
end

@generated function _setindex!(a::AbstractArray, value, ind_sizes::Tuple{Vararg{Val}}, inds)
linearsizes = linear_index_size2(ind_sizes.parameters...)
exprs = Array{Expr}(undef, linearsizes)

# Iterate over input indices
ind_types = inds.parameters
current_ind = ones(Int,length(ind_types))
more = !isempty(exprs)
while more
exprs_tmp = [_ind2(i, current_ind[i], ind_types[i]) for i = 1:length(ind_types)]
exprs[current_ind...] = :(setindex!(a, value, $(exprs_tmp...)))

# increment current_ind
current_ind[1] += 1
for i 1:length(linearsizes)
if current_ind[i] > linearsizes[i]
if i == length(linearsizes)
more = false
break
else
current_ind[i] = 1
current_ind[i+1] += 1
end
else
break
end
end
end

quote
Base.@_propagate_inbounds_meta
$(exprs...)
return a
end
end

# setindex! from an array
@generated function _setindex!(a::AbstractVector, v::AbstractVector, ind_size::Val, ind)
linearsize = linear_index_size(ind_size)
Expand All @@ -225,6 +327,48 @@ end
end
end

@generated function _setindex!(a::AbstractArray, v::AbstractArray, ind_sizes::Tuple{Vararg{Val}}, inds)
linearsizes = linear_index_size2(ind_sizes.parameters...)
exprs = Array{Expr}(undef, linearsizes)

# Iterate over input indices
ind_types = inds.parameters
current_ind = ones(Int,length(ind_types))
more = true
j = 1
while more
exprs_tmp = [_ind2(i, current_ind[i], ind_types[i]) for i = 1:length(ind_types)]
exprs[current_ind...] = :(setindex!(a, v[$j], $(exprs_tmp...)))

# increment current_ind
current_ind[1] += 1
for i 1:length(linearsizes)
if current_ind[i] > linearsizes[i]
if i == length(linearsizes)
more = false
break
else
current_ind[i] = 1
current_ind[i+1] += 1
end
else
break
end
end
j += 1
end

quote
Base.@_propagate_inbounds_meta
if length(v) != $(prod(linearsizes))
newsize = $(linearsizes)
throw(DimensionMismatch("tried to assign $(length(v))-element array to $newsize destination"))
end
$(exprs...)
return a
end
end

# checkindex

Base.checkindex(B::Type{Bool}, inds::AbstractUnitRange, i::TupleIndexing{T}) where T = Base.checkindex(B, inds, unwrap(i))
Expand All @@ -234,6 +378,7 @@ Base.checkindex(B::Type{Bool}, inds::AbstractUnitRange, i::TupleIndexing{T}) whe
# unsafe_view need only deal with vargs of `TupleIndexing`, as wrapped by to_indices.
# i1 is explicitly specified to avoid ambiguities with Base
Base.unsafe_view(A::AbstractVector, ind::TupleIndexing) = Base.unsafe_view(A, unwrap(ind))
Base.unsafe_view(A::AbstractArray, i1::TupleIndexing, indices::TupleIndexing...) = Base.unsafe_view(A, unwrap(i1), map(unwrap, indices)...)

# Views of views need a new method for Base.SubArray because storing indices
# wrapped in TupleIndexing in field indices of SubArray causes all sorts of problems.
Expand Down

2 comments on commit 75cf4ef

@chakravala
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/88605

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.0.3 -m "<description of version>" 75cf4ef0f674f1af6b10ff42b3add2a4166f437b
git push origin v1.0.3

Please sign in to comment.