From 6fb355885e96c40af119f7d360945021ab03f7d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogumi=C5=82=20Kami=C5=84ski?= Date: Mon, 8 Mar 2021 09:46:40 +0100 Subject: [PATCH] Range indexing: error with scalar bool index like all other arrays (#31829) --- NEWS.md | 3 + base/range.jl | 132 ++++++++++++++++++++++++++++++------ base/twiceprecision.jl | 40 +++++++---- test/ranges.jl | 147 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 289 insertions(+), 33 deletions(-) diff --git a/NEWS.md b/NEWS.md index ec62064cf0754..b692db7caa6d7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -51,6 +51,9 @@ Standard library changes * `escape_string` can now receive a collection of characters in the keyword `keep` that are to be kept as they are. ([#38597]). * `getindex` can now be used on `NamedTuple`s with multiple values ([#38878]) +* Subtypes of `AbstractRange` now correctly follow the general array indexing + behavior when indexed by `Bool`s, erroring for scalar `Bool`s and treating + arrays (including ranges) of `Bool` as a logical index ([#31829]) * `keys(::RegexMatch)` is now defined to return the capture's keys, by name if named, or by index if not ([#37299]). * `keys(::Generator)` is now defined to return the iterator's keys ([#34678]) * `RegexMatch` now iterate to give their captures. ([#34355]). diff --git a/base/range.jl b/base/range.jl index 7278d8dc61e9b..54aeef84cae19 100644 --- a/base/range.jl +++ b/base/range.jl @@ -392,12 +392,19 @@ be 1. 
""" struct OneTo{T<:Integer} <: AbstractUnitRange{T} stop::T - OneTo{T}(stop) where {T<:Integer} = new(max(zero(T), stop)) + function OneTo{T}(stop) where {T<:Integer} + throwbool(r) = (@_noinline_meta; throw(ArgumentError("invalid index: $r of type Bool"))) + T === Bool && throwbool(stop) + return new(max(zero(T), stop)) + end + function OneTo{T}(r::AbstractRange) where {T<:Integer} throwstart(r) = (@_noinline_meta; throw(ArgumentError("first element must be 1, got $(first(r))"))) throwstep(r) = (@_noinline_meta; throw(ArgumentError("step must be 1, got $(step(r))"))) + throwbool(r) = (@_noinline_meta; throw(ArgumentError("invalid index: $r of type Bool"))) first(r) == 1 || throwstart(r) step(r) == 1 || throwstep(r) + T === Bool && throwbool(r) return new(max(zero(T), last(r))) end end @@ -748,6 +755,7 @@ _in_unit_range(v::UnitRange, val, i::Integer) = i > 0 && val <= v.stop && val >= function getindex(v::UnitRange{T}, i::Integer) where T @_inline_meta + i isa Bool && throw(ArgumentError("invalid index: $i of type Bool")) val = convert(T, v.start + (i - 1)) @boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i) val @@ -758,6 +766,7 @@ const OverflowSafe = Union{Bool,Int8,Int16,Int32,Int64,Int128, function getindex(v::UnitRange{T}, i::Integer) where {T<:OverflowSafe} @_inline_meta + i isa Bool && throw(ArgumentError("invalid index: $i of type Bool")) val = v.start + (i - 1) @boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i) val % T @@ -765,12 +774,14 @@ end function getindex(v::OneTo{T}, i::Integer) where T @_inline_meta + i isa Bool && throw(ArgumentError("invalid index: $i of type Bool")) @boundscheck ((i > 0) & (i <= v.stop)) || throw_boundserror(v, i) convert(T, i) end function getindex(v::AbstractRange{T}, i::Integer) where T @_inline_meta + i isa Bool && throw(ArgumentError("invalid index: $i of type Bool")) ret = convert(T, first(v) + (i - 1)*step_hp(v)) ok = ifelse(step(v) > zero(step(v)), (ret <= last(v)) & (ret >= first(v)), @@ 
-781,22 +792,26 @@ end function getindex(r::Union{StepRangeLen,LinRange}, i::Integer) @_inline_meta + i isa Bool && throw(ArgumentError("invalid index: $i of type Bool")) @boundscheck checkbounds(r, i) unsafe_getindex(r, i) end # This is separate to make it useful even when running with --check-bounds=yes function unsafe_getindex(r::StepRangeLen{T}, i::Integer) where T + i isa Bool && throw(ArgumentError("invalid index: $i of type Bool")) u = i - r.offset T(r.ref + u*r.step) end function _getindex_hiprec(r::StepRangeLen, i::Integer) # without rounding by T + i isa Bool && throw(ArgumentError("invalid index: $i of type Bool")) u = i - r.offset r.ref + u*r.step end function unsafe_getindex(r::LinRange, i::Integer) + i isa Bool && throw(ArgumentError("invalid index: $i of type Bool")) lerpi(i-1, r.lendiv, r.start, r.stop) end @@ -808,12 +823,27 @@ end getindex(r::AbstractRange, ::Colon) = copy(r) -function getindex(r::AbstractUnitRange, s::AbstractUnitRange{<:Integer}) +function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integer} @_inline_meta @boundscheck checkbounds(r, s) - f = first(r) - st = oftype(f, f + first(s)-1) - range(st, length=length(s)) + + if T === Bool + if length(s) == 0 + return r + elseif length(s) == 1 + if first(s) + return r + else + return range(r[1], length=0) + end + else # length(s) == 2 + return range(r[2], length=1) + end + else + f = first(r) + st = oftype(f, f + first(s)-1) + return range(st, length=length(s)) + end end function getindex(r::OneTo{T}, s::OneTo) where T @@ -822,36 +852,96 @@ function getindex(r::OneTo{T}, s::OneTo) where T OneTo(T(s.stop)) end -function getindex(r::AbstractUnitRange, s::StepRange{<:Integer}) +function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer} @_inline_meta @boundscheck checkbounds(r, s) - st = oftype(first(r), first(r) + s.start-1) - range(st, step=step(s), length=length(s)) + + if T === Bool + if length(s) == 0 + return range(first(r), step=one(eltype(r)), 
length=0) + elseif length(s) == 1 + if first(s) + return range(first(r), step=one(eltype(r)), length=1) + else + return range(first(r), step=one(eltype(r)), length=0) + end + else # length(s) == 2 + return range(r[2], step=one(eltype(r)), length=1) + end + else + st = oftype(first(r), first(r) + s.start-1) + return range(st, step=step(s), length=length(s)) + end end -function getindex(r::StepRange, s::AbstractRange{<:Integer}) +function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer} @_inline_meta @boundscheck checkbounds(r, s) - st = oftype(r.start, r.start + (first(s)-1)*step(r)) - range(st, step=step(r)*step(s), length=length(s)) + + if T === Bool + if length(s) == 0 + return range(first(r), step=step(r), length=0) + elseif length(s) == 1 + if first(s) + return range(first(r), step=step(r), length=1) + else + return range(first(r), step=step(r), length=0) + end + else # length(s) == 2 + return range(r[2], step=step(r), length=1) + end + else + st = oftype(r.start, r.start + (first(s)-1)*step(r)) + return range(st, step=step(r)*step(s), length=length(s)) + end end -function getindex(r::StepRangeLen{T}, s::OrdinalRange{<:Integer}) where {T} +function getindex(r::StepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer} @_inline_meta @boundscheck checkbounds(r, s) - # Find closest approach to offset by s - ind = LinearIndices(s) - offset = max(min(1 + round(Int, (r.offset - first(s))/step(s)), last(ind)), first(ind)) - ref = _getindex_hiprec(r, first(s) + (offset-1)*step(s)) - return StepRangeLen{T}(ref, r.step*step(s), length(s), offset) + + if S === Bool + if length(s) == 0 + return StepRangeLen{T}(first(r), step(r), 0, 1) + elseif length(s) == 1 + if first(s) + return StepRangeLen{T}(first(r), step(r), 1, 1) + else + return StepRangeLen{T}(first(r), step(r), 0, 1) + end + else # length(s) == 2 + return StepRangeLen{T}(r[2], step(r), 1, 1) + end + else + # Find closest approach to offset by s + ind = LinearIndices(s) + offset = max(min(1 + 
round(Int, (r.offset - first(s))/step(s)), last(ind)), first(ind)) + ref = _getindex_hiprec(r, first(s) + (offset-1)*step(s)) + return StepRangeLen{T}(ref, r.step*step(s), length(s), offset) + end end -function getindex(r::LinRange{T}, s::OrdinalRange{<:Integer}) where {T} +function getindex(r::LinRange{T}, s::OrdinalRange{S}) where {T, S<:Integer} @_inline_meta @boundscheck checkbounds(r, s) - vfirst = unsafe_getindex(r, first(s)) - vlast = unsafe_getindex(r, last(s)) - return LinRange{T}(vfirst, vlast, length(s)) + + if S === Bool + if length(s) == 0 + return LinRange(first(r), first(r), 0) + elseif length(s) == 1 + if first(s) + return LinRange(first(r), first(r), 1) + else + return LinRange(first(r), first(r), 0) + end + else # length(s) == 2 + return LinRange(r[2], r[2], 1) + end + else + vfirst = unsafe_getindex(r, first(s)) + vlast = unsafe_getindex(r, last(s)) + return LinRange{T}(vfirst, vlast, length(s)) + end end show(io::IO, r::AbstractRange) = print(io, repr(first(r)), ':', repr(step(r)), ':', repr(last(r))) diff --git a/base/twiceprecision.jl b/base/twiceprecision.jl index e7a2e5041f4ef..4df9c072f78ee 100644 --- a/base/twiceprecision.jl +++ b/base/twiceprecision.jl @@ -448,6 +448,7 @@ end function unsafe_getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, i::Integer) where T # Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12 @_inline_meta + i isa Bool && throw(ArgumentError("invalid index: $i of type Bool")) u = i - r.offset shift_hi, shift_lo = u*r.step.hi, u*r.step.lo x_hi, x_lo = add12(r.ref.hi, shift_hi) @@ -455,6 +456,7 @@ function unsafe_getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, i end function _getindex_hiprec(r::StepRangeLen{<:Any,<:TwicePrecision,<:TwicePrecision}, i::Integer) + i isa Bool && throw(ArgumentError("invalid index: $i of type Bool")) u = i - r.offset shift_hi, shift_lo = u*r.step.hi, u*r.step.lo x_hi, x_lo = add12(r.ref.hi, shift_hi) @@ -462,20 +464,34 @@ function 
_getindex_hiprec(r::StepRangeLen{<:Any,<:TwicePrecision,<:TwicePrecisio TwicePrecision(x_hi, x_lo) end -function getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, s::OrdinalRange{<:Integer}) where T +function getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, s::OrdinalRange{S}) where {T, S<:Integer} @boundscheck checkbounds(r, s) - soffset = 1 + round(Int, (r.offset - first(s))/step(s)) - soffset = clamp(soffset, 1, length(s)) - ioffset = first(s) + (soffset-1)*step(s) - if step(s) == 1 || length(s) < 2 - newstep = r.step - else - newstep = twiceprecision(r.step*step(s), nbitslen(T, length(s), soffset)) - end - if ioffset == r.offset - StepRangeLen(r.ref, newstep, length(s), max(1,soffset)) + if S === Bool + if length(s) == 0 + return StepRangeLen(r.ref, r.step, 0, 1) + elseif length(s) == 1 + if first(s) + return StepRangeLen(r.ref, r.step, 1, 1) + else + return StepRangeLen(r.ref, r.step, 0, 1) + end + else # length(s) == 2 + return StepRangeLen(r[2], step(r), 1, 1) + end else - StepRangeLen(r.ref + (ioffset-r.offset)*r.step, newstep, length(s), max(1,soffset)) + soffset = 1 + round(Int, (r.offset - first(s))/step(s)) + soffset = clamp(soffset, 1, length(s)) + ioffset = first(s) + (soffset-1)*step(s) + if step(s) == 1 || length(s) < 2 + newstep = r.step + else + newstep = twiceprecision(r.step*step(s), nbitslen(T, length(s), soffset)) + end + if ioffset == r.offset + return StepRangeLen(r.ref, newstep, length(s), max(1,soffset)) + else + return StepRangeLen(r.ref + (ioffset-r.offset)*r.step, newstep, length(s), max(1,soffset)) + end end end diff --git a/test/ranges.jl b/test/ranges.jl index 937aabe471561..fd12a0829ce5e 100644 --- a/test/ranges.jl +++ b/test/ranges.jl @@ -1752,3 +1752,150 @@ end @test eltype(StepRangeLen(Int8(1), Int8(2), 3, 2)) === Int8 @test typeof(step(StepRangeLen(Int8(1), Int8(2), 3, 2))) === Int8 end + +@testset "Bool indexing of ranges" begin + @test_throws ArgumentError Base.OneTo(true) + @test_throws 
ArgumentError Base.OneTo(true:true:true) + + @test_throws ArgumentError (1:2)[true] + @test_throws ArgumentError (big(1):big(2))[true] + @test_throws ArgumentError Base.OneTo(10)[true] + @test_throws ArgumentError (1:2:5)[true] + @test_throws ArgumentError LinRange(1,2,2)[true] + @test_throws ArgumentError (1.0:2.0:5.0)[true] + r = 3:2 + r2 = r[true:false] + @test r2 == collect(r)[true:false] + @test r.start == r2.start && r.stop == r2.stop + @test_throws BoundsError r[true:true] + @test_throws BoundsError r[false:true] + r = 3:3 + r2 = r[true:true] + @test r2 == collect(r)[true:true] + @test r.start == r2.start && r.stop == r2.stop + r2 = r[false:false] + @test r2.start == 3 && r2.stop == 2 + @test_throws BoundsError r[true:false] + @test_throws BoundsError r[false:true] + r = 2:3 + r2 = r[false:true] + @test r2 == collect(r)[false:true] + @test r2.start == r2.stop == 3 + @test_throws BoundsError r[true:false] + @test_throws BoundsError r[true:true] + + r = 2:1 + r2 = r[true:true:false] + @test r2 == collect(r)[true:true:false] + @test r2 isa StepRange && r2.start == 2 && r2.step == 1 && r2.stop == 1 + @test_throws BoundsError r[false:true:false] + + r = 2:2 + r2 = r[false:true:false] + @test r2 == collect(r)[false:true:false] + @test r2 isa StepRange && r2.start == 2 && r2.step == 1 && r2.stop == 1 + r2 = r[true:true:true] + @test r2 == collect(r)[true:true:true] + @test r2 isa StepRange && r2.start == 2 && r2.step == 1 && r2.stop == 2 + @test_throws BoundsError r[true:true:false] + @test_throws BoundsError r[false:true:true] + + r = 1:2 + r2 = r[false:true:true] + @test r2 == collect(r)[false:true:true] + @test r2 isa StepRange && r2.start == 2 && r2.step == 1 && r2.stop == 2 + @test_throws BoundsError r[true:true:false] + @test_throws BoundsError r[true:true:true] + + r = 2:1:1 + r2 = r[true:true:false] + @test r2 == collect(r)[true:true:false] + @test r2 isa StepRange && r2.start == 2 && r2.step == 1 && r2.stop == 1 + @test_throws BoundsError 
r[false:true:false] + + r = 2:1:2 + r2 = r[false:true:false] + @test r2 == collect(r)[false:true:false] + @test r2 isa StepRange && r2.start == 2 && r2.step == 1 && r2.stop == 1 + r2 = r[true:true:true] + @test r2 == collect(r)[true:true:true] + @test r2 isa StepRange && r2.start == 2 && r2.step == 1 && r2.stop == 2 + @test_throws BoundsError r[true:true:false] + @test_throws BoundsError r[false:true:true] + + r = 1:1:2 + r2 = r[false:true:true] + @test r2 == collect(r)[false:true:true] + @test r2 isa StepRange && r2.start == 2 && r2.step == 1 && r2.stop == 2 + @test_throws BoundsError r[true:true:false] + @test_throws BoundsError r[true:true:true] + + r = 2.0:1.0:1.0 + r2 = r[true:true:false] + @test r2 == collect(r)[true:true:false] + @test r2 isa StepRangeLen && r2 == 2:1 + @test_throws BoundsError r[false:true:false] + + r = 2.0:1.0:2.0 + r2 = r[false:true:false] + @test r2 == collect(r)[false:true:false] + @test r2 isa StepRangeLen && r2 == 2:1 + r2 = r[true:true:true] + @test r2 == collect(r)[true:true:true] + @test r2 isa StepRangeLen && r2 == 2:2 + @test_throws BoundsError r[true:true:false] + @test_throws BoundsError r[false:true:true] + + r = 1.0:1.0:2.0 + r2 = r[false:true:true] + @test r2 == collect(r)[false:true:true] + @test r2 isa StepRangeLen && r2 == 2:2 + @test_throws BoundsError r[true:true:false] + @test_throws BoundsError r[true:true:true] + + r = StepRangeLen(2, 1, 0) + r2 = r[true:true:false] + @test r2 == collect(r)[true:true:false] + @test r2 isa StepRangeLen && r2 == 2:1 + @test_throws BoundsError r[false:true:false] + + r = StepRangeLen(2, 1, 1) + r2 = r[false:true:false] + @test r2 == collect(r)[false:true:false] + @test r2 isa StepRangeLen && r2 == 2:1 + r2 = r[true:true:true] + @test r2 == collect(r)[true:true:true] + @test r2 isa StepRangeLen && r2 == 2:2 + @test_throws BoundsError r[true:true:false] + @test_throws BoundsError r[false:true:true] + + r = StepRangeLen(1, 1, 2) + r2 = r[false:true:true] + @test r2 == 
collect(r)[false:true:true] + @test r2 isa StepRangeLen && r2 == 2:2 + @test_throws BoundsError r[true:true:false] + @test_throws BoundsError r[true:true:true] + + r = LinRange(2, 1, 0) + r2 = r[true:true:false] + @test r2 == collect(r)[true:true:false] + @test r2 isa LinRange && r2 == 2:1 + @test_throws BoundsError r[false:true:false] + + r = LinRange(2, 2, 1) + r2 = r[false:true:false] + @test r2 == collect(r)[false:true:false] + @test r2 isa LinRange && r2 == 2:1 + r2 = r[true:true:true] + @test r2 == collect(r)[true:true:true] + @test r2 isa LinRange && r2 == 2:2 + @test_throws BoundsError r[true:true:false] + @test_throws BoundsError r[false:true:true] + + r = LinRange(1, 2, 2) + r2 = r[false:true:true] + @test r2 == collect(r)[false:true:true] + @test r2 isa LinRange && r2 == 2:2 + @test_throws BoundsError r[true:true:false] + @test_throws BoundsError r[true:true:true] +end