Skip to content

Commit

Permalink
Allow non-static indices
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Dec 22, 2019
1 parent 103e9d4 commit 3745b09
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ end
@inline index_size(::Size, ::Int) = Size()
@inline index_size(::Size, a::StaticArray) = Size(a)
@inline index_size(s::Size, ::Colon) = s
@inline index_size(s::Size, a::SOneTo{n}) where n = Size(n,)
@inline index_size(::Size, a::AbstractRange{<:Integer}) = Size(length(a),)

@inline index_sizes(::S, inds...) where {S<:Size} = map(index_size, unpack_size(S), inds)

Expand All @@ -92,9 +92,9 @@ linear_index_size(ind_sizes::Type{<:Size}...) = _linear_index_size((), ind_sizes
@inline _linear_index_size(t::Tuple, ::Type{Size{S}}, ind_sizes...) where {S} = _linear_index_size((t..., prod(S)), ind_sizes...)

_ind(i::Int, ::Int, ::Type{Int}) = :(inds[$i])
_ind(i::Int, j::Int, ::Type{<:StaticArray}) = :(inds[$i][$j])
_ind(i::Int, j::Int, ::Type{Colon}) = j
_ind(i::Int, j::Int, ::Type{<:SOneTo}) = j
_ind(i::Int, j::Int, ::Type{<:AbstractArray}) = :(inds[$i][$j])

################################
## Non-scalar linear indexing ##
Expand Down Expand Up @@ -215,7 +215,7 @@ end

# getindex

@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon}...)
@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, AbstractRange, Colon}...)
_getindex(a, index_sizes(Size(a), inds...), inds)
end

Expand Down
2 changes: 1 addition & 1 deletion test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ using StaticArrays, Test, LinearAlgebra
@test similar(v, SOneTo(3), SOneTo(4)) isa MMatrix{3,4,Int}
@test similar(v, 3, SOneTo(4)) isa Matrix

@test m[:, 1:2] isa Matrix
@test m[:, 1:2] isa SMatrix{2, 2, Int}
@test m[:, [true, false, false]] isa Matrix
@test m[:, SOneTo(2)] isa SMatrix{2, 2, Int}
@test m[:, :] isa SMatrix{2, 3, Int}
Expand Down
30 changes: 30 additions & 0 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,4 +223,34 @@ using StaticArrays, Test
@test eltype(Bvv) == Int
@test Bvv[:] == [B[1,2,3,4], B[1,1,3,4]]
end

@testset "Indexing with constants" begin
function SVector_UnitRange()
x = SA[1, 2, 3]
x[2:end]
end
@test SVector_UnitRange() === SA[2, 3]
VERSION v"1.1" && @test_const_fold SVector_UnitRange()

function SVector_StepRange()
x = SA[1, 2, 3, 4]
x[1:2:end]
end
@test SVector_StepRange() === SA[1, 3]
VERSION v"1.1" && @test_const_fold SVector_StepRange()

function SMatrix_UnitRange_UnitRange()
x = SA[1 2 3; 4 5 6]
x[1:2, 2:end]
end
@test SMatrix_UnitRange_UnitRange() === SA[2 3; 5 6]
VERSION v"1.1" && @test_const_fold SMatrix_UnitRange_UnitRange()

function SMatrix_StepRange_StepRange()
x = SA[1 2 3; 4 5 6]
x[1:1:2, 1:2:end]
end
@test SMatrix_StepRange_StepRange() === SA[1 3; 4 6]
VERSION v"1.1" && @test_const_fold SMatrix_StepRange_StepRange()
end
end
39 changes: 39 additions & 0 deletions test/testutil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,45 @@ should_not_be_inlined(x) = _should_not_be_inlined(x)
end


"""
@test_const_fold f(args...)
Test that constant folding works with a function call `f(args...)`.
"""
macro test_const_fold(ex)
quote
ci, = $(esc(:($InteractiveUtils.@code_typed optimize = true $ex)))
@test $(esc(ex)) == constant_return(ci)
end
end

struct NonConstantValue end

function constant_return(ci)
if :rettype in fieldnames(typeof(ci))
ci.rettype isa Core.Compiler.Const && return ci.rettype.val
return NonConstantValue()
else
# for julia < 1.2
ex = ci.code[end]
Meta.isexpr(ex, :return) || return NonConstantValue()
val = ex.args[1]
return val isa QuoteNode ? val.value : val
end
end

@testset "@test_const_fold" begin
should_const_fold() = (1, 2, 3)
@test_const_fold should_const_fold()

x = Ref(1)
should_not_const_fold() = x[]
ts = @testset ErrorCounterTestSet "" begin
@test_const_fold should_not_const_fold()
end
@test ts.errorcount == 0 && ts.failcount == 1 && ts.passcount == 0
end

"""
@inferred_maybe_allow allow ex
Expand Down

0 comments on commit 3745b09

Please sign in to comment.