Skip to content

Commit

Permalink
improve type-based offset axes check (JuliaLang#45260)
Browse files Browse the repository at this point in the history
* Follow up to JuliaLang#45236 (make `length(::StepRange{Int8,Int128})` type-stable)
* Fully drop `_tuple_any` (unneeded now)
* Make sure `has_offset_axes(::StepRange)` could be const folded.
  And define some "cheap" `firstindex`
* Do offset axes check on `A`'s parent rather than itself.
  This avoid some unneeded `axes` call, thus more possible be folded by the compiler.

Co-authored-by: Jameson Nash <[email protected]>
  • Loading branch information
2 people authored and Francesco Fucci committed Aug 11, 2022
1 parent b05a843 commit aa3ebdd
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 45 deletions.
6 changes: 4 additions & 2 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,11 @@ If multiple arguments are passed, equivalent to `has_offset_axes(A) | has_offset
See also [`require_one_based_indexing`](@ref).
"""
has_offset_axes(A) = _tuple_any(x->Int(first(x))::Int != 1, axes(A))
has_offset_axes(A) = _any_tuple(x->Int(first(x))::Int != 1, false, axes(A)...)
has_offset_axes(A::AbstractVector) = Int(firstindex(A))::Int != 1 # improve performance of a common case (ranges)
has_offset_axes(A...) = _tuple_any(has_offset_axes, A)
# Use `_any_tuple` to avoid unneeded invoke.
# note: this could call `any` directly if the compiler can infer it
has_offset_axes(As...) = _any_tuple(has_offset_axes, false, As...)
has_offset_axes(::Colon) = false

"""
Expand Down
1 change: 1 addition & 0 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ module IteratorsMD
# AbstractArray implementation
Base.axes(iter::CartesianIndices{N,R}) where {N,R} = map(Base.axes1, iter.indices)
Base.IndexStyle(::Type{CartesianIndices{N,R}}) where {N,R} = IndexCartesian()
Base.has_offset_axes(iter::CartesianIndices) = Base.has_offset_axes(iter.indices...)
# getindex for a 0D CartesianIndices is necessary for disambiguation
@propagate_inbounds function Base.getindex(iter::CartesianIndices{0,R}) where {R}
CartesianIndex()
Expand Down
1 change: 1 addition & 0 deletions base/permuteddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ end
Base.parent(A::PermutedDimsArray) = A.parent
Base.size(A::PermutedDimsArray{T,N,perm}) where {T,N,perm} = genperm(size(parent(A)), perm)
Base.axes(A::PermutedDimsArray{T,N,perm}) where {T,N,perm} = genperm(axes(parent(A)), perm)
Base.has_offset_axes(A::PermutedDimsArray) = Base.has_offset_axes(A.parent)

Base.similar(A::PermutedDimsArray, T::Type, dims::Base.Dims) = similar(parent(A), T, dims)

Expand Down
47 changes: 26 additions & 21 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,9 @@ step_hp(r::AbstractRange) = step(r)

axes(r::AbstractRange) = (oneto(length(r)),)

# Needed to ensure `has_offset_axes` can constant-fold.
has_offset_axes(::StepRange) = false

# n.b. checked_length for these is defined iff checked_add and checked_sub are
# defined between the relevant types
function checked_length(r::OrdinalRange{T}) where T
Expand Down Expand Up @@ -750,64 +753,66 @@ length(r::OneTo) = Integer(r.stop - zero(r.stop))
length(r::StepRangeLen) = r.len
length(r::LinRange) = r.len

let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
global length, checked_length
let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128},
smallints = (Int === Int64 ?
Union{Int8, UInt8, Int16, UInt16, Int32, UInt32} :
Union{Int8, UInt8, Int16, UInt16}),
bitints = Union{bigints, smallints}
global length, checked_length, firstindex
# compile optimization for which promote_type(T, Int) == T
length(r::OneTo{T}) where {T<:bigints} = r.stop
# slightly more accurate length and checked_length in extreme cases
# (near typemax) for types with known `unsigned` functions
function length(r::OrdinalRange{T}) where T<:bigints
s = step(r)
isempty(r) && return zero(T)
diff = last(r) - first(r)
isempty(r) && return zero(diff)
# if |s| > 1, diff might have overflowed, but unsigned(diff)÷s should
# therefore still be valid (if the result is representable at all)
# n.b. !(s isa T)
if s isa Unsigned || -1 <= s <= 1 || s == -s
a = div(diff, s) % T
a = div(diff, s) % typeof(diff)
elseif s < 0
a = div(unsigned(-diff), -s) % T
a = div(unsigned(-diff), -s) % typeof(diff)
else
a = div(unsigned(diff), s) % T
a = div(unsigned(diff), s) % typeof(diff)
end
return a + oneunit(T)
return a + oneunit(a)
end
function checked_length(r::OrdinalRange{T}) where T<:bigints
s = step(r)
isempty(r) && return zero(T)
stop, start = last(r), first(r)
ET = promote_type(typeof(stop), typeof(start))
isempty(r) && return zero(ET)
# n.b. !(s isa T)
if s > 1
diff = stop - start
a = convert(T, div(unsigned(diff), s))
a = convert(ET, div(unsigned(diff), s))
elseif s < -1
diff = start - stop
a = convert(T, div(unsigned(diff), -s))
a = convert(ET, div(unsigned(diff), -s))
elseif s > 0
a = div(checked_sub(stop, start), s)
a = convert(ET, div(checked_sub(stop, start), s))
else
a = div(checked_sub(start, stop), -s)
a = convert(ET, div(checked_sub(start, stop), -s))
end
return checked_add(convert(T, a), oneunit(T))
return checked_add(a, oneunit(a))
end
end
firstindex(r::StepRange{<:bigints,<:bitints}) = one(last(r)-first(r))

# some special cases to favor default Int type
let smallints = (Int === Int64 ?
Union{Int8, UInt8, Int16, UInt16, Int32, UInt32} :
Union{Int8, UInt8, Int16, UInt16})
global length, checked_length
# n.b. !(step isa T)
# some special cases to favor default Int type
function length(r::OrdinalRange{<:smallints})
s = step(r)
isempty(r) && return 0
return div(Int(last(r)) - Int(first(r)), s) + 1
# n.b. !(step isa T)
return Int(div(Int(last(r)) - Int(first(r)), s)) + 1
end
length(r::AbstractUnitRange{<:smallints}) = Int(last(r)) - Int(first(r)) + 1
length(r::OneTo{<:smallints}) = Int(r.stop)
checked_length(r::OrdinalRange{<:smallints}) = length(r)
checked_length(r::AbstractUnitRange{<:smallints}) = length(r)
checked_length(r::OneTo{<:smallints}) = length(r)
firstindex(::StepRange{<:smallints,<:bitints}) = 1
end

first(r::OrdinalRange{T}) where {T} = convert(T, r.start)
Expand Down
2 changes: 2 additions & 0 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ function axes(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
end
axes(a::NonReshapedReinterpretArray{T,0}) where {T} = ()

has_offset_axes(a::ReinterpretArray) = has_offset_axes(a.parent)

elsize(::Type{<:ReinterpretArray{T}}) where {T} = sizeof(T)
unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} = Ptr{T}(unsafe_convert(Ptr{S},a.parent))

Expand Down
2 changes: 2 additions & 0 deletions base/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,5 @@ function _indices_sub(i1::AbstractArray, I...)
@inline
(axes(i1)..., _indices_sub(I...)...)
end

has_offset_axes(S::SubArray) = has_offset_axes(S.indices...)
9 changes: 0 additions & 9 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -581,15 +581,6 @@ any(x::Tuple{Bool}) = x[1]
any(x::Tuple{Bool, Bool}) = x[1]|x[2]
any(x::Tuple{Bool, Bool, Bool}) = x[1]|x[2]|x[3]

# equivalent to any(f, t), to be used only in bootstrap
_tuple_any(f::Function, t::Tuple) = _tuple_any(f, false, t...)
function _tuple_any(f::Function, tf::Bool, a, b...)
@inline
_tuple_any(f, tf | f(a), b...)
end
_tuple_any(f::Function, tf::Bool) = tf


# a version of `in` esp. for NamedTuple, to make it pure, and not compiled for each tuple length
function sym_in(x::Symbol, @nospecialize itr::Tuple{Vararg{Symbol}})
@_total_meta
Expand Down
33 changes: 29 additions & 4 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1583,11 +1583,11 @@ end
@test length(rr) == length(r)
end

struct FakeZeroDimArray <: AbstractArray{Int, 0} end
Base.strides(::FakeZeroDimArray) = ()
Base.size(::FakeZeroDimArray) = ()
module IRUtils
include("compiler/irutils.jl")
end

@testset "strides for ReshapedArray" begin
# Type-based contiguous check is tested in test/compiler/inline.jl
function check_strides(A::AbstractArray)
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
dims = ntuple(identity, ndims(A))
Expand All @@ -1598,6 +1598,10 @@ Base.size(::FakeZeroDimArray) = ()
end
return true
end
# Type-based contiguous Check
a = vec(reinterpret(reshape, Int16, reshape(view(reinterpret(Int32, randn(10)), 2:11), 5, :)))
f(a) = only(strides(a));
@test IRUtils.fully_eliminated(f, Base.typesof(a)) && f(a) == 1
# General contiguous check
a = view(rand(10,10), 1:10, 1:10)
@test check_strides(vec(a))
Expand Down Expand Up @@ -1629,6 +1633,9 @@ Base.size(::FakeZeroDimArray) = ()
@test_throws "Input is not strided." strides(reshape(a,3,5,3,2))
@test_throws "Input is not strided." strides(reshape(a,5,3,3,2))
# Zero dimensional parent
struct FakeZeroDimArray <: AbstractArray{Int, 0} end
Base.strides(::FakeZeroDimArray) = ()
Base.size(::FakeZeroDimArray) = ()
a = reshape(FakeZeroDimArray(),1,1,1)
@test @inferred(strides(a)) == (1, 1, 1)
# Dense parent (but not StridedArray)
Expand Down Expand Up @@ -1660,3 +1667,21 @@ end
@test (@inferred A[i,i,i]) === A[1]
@test (@inferred to_indices([], (1, CIdx(1, 1), 1, CIdx(1, 1), 1, CIdx(1, 1), 1))) == ntuple(Returns(1), 10)
end

@testset "type-based offset axes check" begin
a = randn(ComplexF64, 10)
ta = reinterpret(Float64, a)
tb = reinterpret(Float64, view(a, 1:2:10))
tc = reinterpret(Float64, reshape(view(a, 1:3:10), 2, 2, 1))
# Issue #44040
@test IRUtils.fully_eliminated(Base.require_one_based_indexing, Base.typesof(ta, tc))
@test IRUtils.fully_eliminated(Base.require_one_based_indexing, Base.typesof(tc, tc))
@test IRUtils.fully_eliminated(Base.require_one_based_indexing, Base.typesof(ta, tc, tb))
# Ranges && CartesianIndices
@test IRUtils.fully_eliminated(Base.require_one_based_indexing, Base.typesof(1:10, Base.OneTo(10), 1.0:2.0, LinRange(1.0, 2.0, 2), 1:2:10, CartesianIndices((1:2:10, 1:2:10))))
# Remind us to call `any` in `Base.has_offset_axes` once our compiler is ready.
@inline _has_offset_axes(A) = @inline any(x -> Int(first(x))::Int != 1, axes(A))
@inline _has_offset_axes(As...) = @inline any(_has_offset_axes, As)
a, b = zeros(2, 2, 2), zeros(2, 2)
@test_broken IRUtils.fully_eliminated(_has_offset_axes, Base.typesof(a, a, b, b))
end
7 changes: 0 additions & 7 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -988,13 +988,6 @@ end
@invoke conditional_escape!(false::Any, x::Any)
end

@testset "strides for ReshapedArray (PR#44027)" begin
# Type-based contiguous check
a = vec(reinterpret(reshape,Int16,reshape(view(reinterpret(Int32,randn(10)),2:11),5,:)))
f(a) = only(strides(a));
@test fully_eliminated(f, Tuple{typeof(a)}) && f(a) == 1
end

@testset "elimination of `get_binding_type`" begin
m = Module()
@eval m begin
Expand Down
23 changes: 21 additions & 2 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2031,8 +2031,17 @@ end
end

@testset "length(StepRange()) type stability" begin
typeof(length(StepRange(1,Int128(1),1))) == typeof(length(StepRange(1,Int128(1),0)))
typeof(checked_length(StepRange(1,Int128(1),1))) == typeof(checked_length(StepRange(1,Int128(1),0)))
for SR in (StepRange{Int,Int128}, StepRange{Int8,Int128})
r1, r2 = SR(1, 1, 1), SR(1, 1, 0)
@test typeof(length(r1)) == typeof(checked_length(r1)) ==
typeof(length(r2)) == typeof(checked_length(r2))
end
SR = StepRange{Union{Int64,Int128},Int}
test_length(r, l) = length(r) === checked_length(r) === l
@test test_length(SR(Int64(1), 1, Int128(1)), Int128(1))
@test test_length(SR(Int64(1), 1, Int128(0)), Int128(0))
@test test_length(SR(Int64(1), 1, Int64(1)), Int64(1))
@test test_length(SR(Int64(1), 1, Int64(0)), Int64(0))
end

@testset "LinRange eltype for element types that wrap integers" begin
Expand Down Expand Up @@ -2346,3 +2355,13 @@ end
@test isempty(range(typemax(Int), length=0, step=UInt(2)))

@test length(range(1, length=typemax(Int128))) === typemax(Int128)

@testset "firstindex(::StepRange{<:Base.BitInteger})" begin
test_firstindex(x) = firstindex(x) === first(Base.axes1(x))
for T in Base.BitInteger_types, S in Base.BitInteger_types
@test test_firstindex(StepRange{T,S}(1, 1, 1))
@test test_firstindex(StepRange{T,S}(1, 1, 0))
end
@test test_firstindex(StepRange{Union{Int64,Int128},Int}(Int64(1), 1, Int128(1)))
@test test_firstindex(StepRange{Union{Int64,Int128},Int}(Int64(1), 1, Int128(0)))
end

0 comments on commit aa3ebdd

Please sign in to comment.