Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Range indexing: error with scalar bool index like all other arrays #31829

Merged
merged 25 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ Standard library changes
* `escape_string` can now receive a collection of characters in the keyword
`keep` that are to be kept as they are. ([#38597]).
* `getindex` can now be used on `NamedTuple`s with multiple values ([#38878])
* Subtypes of `AbstractRange` now correctly follow the general array indexing
behavior when indexed by `Bool`s, erroring for scalar `Bool`s and treating
arrays (including ranges) of `Bool` as a logical index ([#31829])
* `keys(::RegexMatch)` is now defined to return the capture's keys, by name if named, or by index if not ([#37299]).
* `RegexMatch` objects now iterate to give their captures ([#34355]).

Expand Down
132 changes: 111 additions & 21 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,12 +406,19 @@ be 1.
"""
struct OneTo{T<:Integer} <: AbstractUnitRange{T}
stop::T
OneTo{T}(stop) where {T<:Integer} = new(max(zero(T), stop))
function OneTo{T}(stop) where {T<:Integer}
throwbool(r) = (@_noinline_meta; throw(ArgumentError("invalid index: $r of type Bool")))
T === Bool && throwbool(stop)
return new(max(zero(T), stop))
end

function OneTo{T}(r::AbstractRange) where {T<:Integer}
throwstart(r) = (@_noinline_meta; throw(ArgumentError("first element must be 1, got $(first(r))")))
throwstep(r) = (@_noinline_meta; throw(ArgumentError("step must be 1, got $(step(r))")))
throwbool(r) = (@_noinline_meta; throw(ArgumentError("invalid index: $r of type Bool")))
first(r) == 1 || throwstart(r)
step(r) == 1 || throwstep(r)
T === Bool && throwbool(r)
return new(max(zero(T), last(r)))
end
end
Expand Down Expand Up @@ -762,6 +769,7 @@ _in_unit_range(v::UnitRange, val, i::Integer) = i > 0 && val <= v.stop && val >=

function getindex(v::UnitRange{T}, i::Integer) where T
@_inline_meta
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
val = convert(T, v.start + (i - 1))
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
val
Expand All @@ -772,19 +780,22 @@ const OverflowSafe = Union{Bool,Int8,Int16,Int32,Int64,Int128,

# Scalar indexing into a `UnitRange` whose element type belongs to `OverflowSafe`.
# A `Bool` index is rejected with an `ArgumentError`. Otherwise the candidate
# element is `v.start + (i - 1)`; it is validated with `_in_unit_range` (under
# `@boundscheck`) and finally narrowed back to `T` with `% T`.
function getindex(v::UnitRange{T}, i::Integer) where {T<:OverflowSafe}
    @_inline_meta
    if i isa Bool
        throw(ArgumentError("invalid index: $i of type Bool"))
    end
    candidate = v.start + (i - 1)
    @boundscheck begin
        if !_in_unit_range(v, candidate, i)
            throw_boundserror(v, i)
        end
    end
    return candidate % T
end

# Scalar indexing into a `OneTo` range: the element at position `i` is `i` itself,
# converted to the range's element type `T`. A `Bool` index raises an
# `ArgumentError`; an index outside `1:v.stop` is reported via `throw_boundserror`
# (the check sits under `@boundscheck`).
function getindex(v::OneTo{T}, i::Integer) where T
    @_inline_meta
    if i isa Bool
        throw(ArgumentError("invalid index: $i of type Bool"))
    end
    @boundscheck begin
        # Non-short-circuit `&` is kept on purpose: both comparisons are evaluated.
        inside = (i > 0) & (i <= v.stop)
        inside || throw_boundserror(v, i)
    end
    return convert(T, i)
end

function getindex(v::AbstractRange{T}, i::Integer) where T
@_inline_meta
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
ret = convert(T, first(v) + (i - 1)*step_hp(v))
ok = ifelse(step(v) > zero(step(v)),
(ret <= last(v)) & (ret >= first(v)),
Expand All @@ -795,22 +806,26 @@ end

# Scalar indexing for `StepRangeLen` and `LinRange`: reject `Bool` indices with an
# `ArgumentError`, run `checkbounds` (under `@boundscheck`), then delegate the
# actual element computation to `unsafe_getindex`.
function getindex(r::Union{StepRangeLen,LinRange}, i::Integer)
    @_inline_meta
    if i isa Bool
        throw(ArgumentError("invalid index: $i of type Bool"))
    end
    @boundscheck checkbounds(r, i)
    return unsafe_getindex(r, i)
end

# Element computation without any bounds check. It lives in its own function so
# that it stays useful even when running with --check-bounds=yes.
# The element is `r.ref` plus `(i - r.offset)` steps, converted to `T`.
# `Bool` indices are rejected with an `ArgumentError`.
function unsafe_getindex(r::StepRangeLen{T}, i::Integer) where T
    if i isa Bool
        throw(ArgumentError("invalid index: $i of type Bool"))
    end
    steps_from_ref = i - r.offset
    return T(r.ref + steps_from_ref*r.step)
end

# Same arithmetic as `unsafe_getindex` for `StepRangeLen`, but the result is
# returned as computed, i.e. without rounding by T.
# `Bool` indices are rejected with an `ArgumentError`.
function _getindex_hiprec(r::StepRangeLen, i::Integer)
    if i isa Bool
        throw(ArgumentError("invalid index: $i of type Bool"))
    end
    steps_from_ref = i - r.offset
    return r.ref + steps_from_ref*r.step
end

# Element computation for `LinRange` without a bounds check: the zero-based
# position `i - 1` is handed to `lerpi` together with `r.lendiv`, `r.start`
# and `r.stop`. `Bool` indices are rejected with an `ArgumentError`.
function unsafe_getindex(r::LinRange, i::Integer)
    if i isa Bool
        throw(ArgumentError("invalid index: $i of type Bool"))
    end
    zero_based = i - 1
    return lerpi(zero_based, r.lendiv, r.start, r.stop)
end

Expand All @@ -822,12 +837,27 @@ end

# Indexing a range with `:` selects every element; the result is `copy(r)`.
function getindex(r::AbstractRange, ::Colon)
    return copy(r)
end

function getindex(r::AbstractUnitRange, s::AbstractUnitRange{<:Integer})
function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)
f = first(r)
st = oftype(f, f + first(s)-1)
range(st, length=length(s))

if T === Bool
if length(s) == 0
return r
elseif length(s) == 1
if first(s)
return r
else
return range(r[1], length=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return range(r[1], length=0)
return range(first(r), length=0)

In case someone makes an OffsetUnitRange?

end
else # length(s) == 2
return range(r[2], length=1)
end
else
f = first(r)
st = oftype(f, f + first(s)-1)
return range(st, length=length(s))
end
end

function getindex(r::OneTo{T}, s::OneTo) where T
Expand All @@ -836,36 +866,96 @@ function getindex(r::OneTo{T}, s::OneTo) where T
OneTo(T(s.stop))
end

function getindex(r::AbstractUnitRange, s::StepRange{<:Integer})
function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)
st = oftype(first(r), first(r) + s.start-1)
range(st, step=step(s), length=length(s))

if T === Bool
if length(s) == 0
return range(first(r), step=one(eltype(r)), length=0)
elseif length(s) == 1
if first(s)
return range(first(r), step=one(eltype(r)), length=1)
else
return range(first(r), step=one(eltype(r)), length=0)
end
else # length(s) == 2
return range(r[2], step=one(eltype(r)), length=1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return range(r[2], step=one(eltype(r)), length=1)
step = one(eltype(r))
return range(first(r) + step; step, length=1)

end
else
st = oftype(first(r), first(r) + s.start-1)
return range(st, step=step(s), length=length(s))
end
end

function getindex(r::StepRange, s::AbstractRange{<:Integer})
function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)
st = oftype(r.start, r.start + (first(s)-1)*step(r))
range(st, step=step(r)*step(s), length=length(s))

if T === Bool
if length(s) == 0
return range(first(r), step=step(r), length=0)
elseif length(s) == 1
if first(s)
return range(first(r), step=step(r), length=1)
else
return range(first(r), step=step(r), length=0)
end
else # length(s) == 2
return range(r[2], step=step(r), length=1)
end
else
st = oftype(r.start, r.start + (first(s)-1)*step(r))
return range(st, step=step(r)*step(s), length=length(s))
end
end

function getindex(r::StepRangeLen{T}, s::OrdinalRange{<:Integer}) where {T}
function getindex(r::StepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)
# Find closest approach to offset by s
ind = LinearIndices(s)
offset = max(min(1 + round(Int, (r.offset - first(s))/step(s)), last(ind)), first(ind))
ref = _getindex_hiprec(r, first(s) + (offset-1)*step(s))
return StepRangeLen{T}(ref, r.step*step(s), length(s), offset)

if S === Bool
if length(s) == 0
return StepRangeLen{T}(first(r), step(r), 0, 1)
elseif length(s) == 1
if first(s)
return StepRangeLen{T}(first(r), step(r), 1, 1)
else
return StepRangeLen{T}(first(r), step(r), 0, 1)
end
else # length(s) == 2
return StepRangeLen{T}(r[2], step(r), 1, 1)
end
else
# Find closest approach to offset by s
ind = LinearIndices(s)
offset = max(min(1 + round(Int, (r.offset - first(s))/step(s)), last(ind)), first(ind))
ref = _getindex_hiprec(r, first(s) + (offset-1)*step(s))
return StepRangeLen{T}(ref, r.step*step(s), length(s), offset)
end
end

function getindex(r::LinRange{T}, s::OrdinalRange{<:Integer}) where {T}
function getindex(r::LinRange{T}, s::OrdinalRange{S}) where {T, S<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)
vfirst = unsafe_getindex(r, first(s))
vlast = unsafe_getindex(r, last(s))
return LinRange{T}(vfirst, vlast, length(s))

if S === Bool
if length(s) == 0
return LinRange(first(r), first(r), 0)
elseif length(s) == 1
if first(s)
return LinRange(first(r), first(r), 1)
else
return LinRange(first(r), first(r), 0)
end
else # length(s) == 2
return LinRange(r[2], r[2], 1)
end
else
vfirst = unsafe_getindex(r, first(s))
vlast = unsafe_getindex(r, last(s))
return LinRange{T}(vfirst, vlast, length(s))
end
end

# Print a range as `first:step:last`, using `repr` for each of the three parts.
function show(io::IO, r::AbstractRange)
    lo, stride, hi = first(r), step(r), last(r)
    return print(io, repr(lo), ':', repr(stride), ':', repr(hi))
end
Expand Down
40 changes: 28 additions & 12 deletions base/twiceprecision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -448,34 +448,50 @@ end
# Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12.
# The hi and lo parts of the step are each scaled by the distance from the
# offset, combined with the reference value, and the sum is converted to `T`.
# `Bool` indices are rejected with an `ArgumentError`.
function unsafe_getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, i::Integer) where T
    @_inline_meta
    if i isa Bool
        throw(ArgumentError("invalid index: $i of type Bool"))
    end
    dist = i - r.offset
    shift_hi = dist*r.step.hi
    shift_lo = dist*r.step.lo
    sum_hi, sum_lo = add12(r.ref.hi, shift_hi)
    # The parenthesization below fixes the floating-point summation order.
    return T(sum_hi + (sum_lo + (shift_lo + r.ref.lo)))
end

# High-precision element lookup for `StepRangeLen` ranges whose reference and
# step are both `TwicePrecision`: the result is returned as a `TwicePrecision`
# built from two `add12` passes rather than being converted to the element type.
# `Bool` indices are rejected with an `ArgumentError`.
function _getindex_hiprec(r::StepRangeLen{<:Any,<:TwicePrecision,<:TwicePrecision}, i::Integer)
    if i isa Bool
        throw(ArgumentError("invalid index: $i of type Bool"))
    end
    dist = i - r.offset
    shift_hi = dist*r.step.hi
    shift_lo = dist*r.step.lo
    first_hi, first_lo = add12(r.ref.hi, shift_hi)
    # Second pass folds the remaining low-order terms in; summation order is fixed.
    final_hi, final_lo = add12(first_hi, first_lo + (shift_lo + r.ref.lo))
    return TwicePrecision(final_hi, final_lo)
end

function getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, s::OrdinalRange{<:Integer}) where T
function getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, s::OrdinalRange{S}) where {T, S<:Integer}
@boundscheck checkbounds(r, s)
soffset = 1 + round(Int, (r.offset - first(s))/step(s))
soffset = clamp(soffset, 1, length(s))
ioffset = first(s) + (soffset-1)*step(s)
if step(s) == 1 || length(s) < 2
newstep = r.step
else
newstep = twiceprecision(r.step*step(s), nbitslen(T, length(s), soffset))
end
if ioffset == r.offset
StepRangeLen(r.ref, newstep, length(s), max(1,soffset))
if S === Bool
if length(s) == 0
return StepRangeLen(r.ref, r.step, 0, 1)
elseif length(s) == 1
if first(s)
return StepRangeLen(r.ref, r.step, 1, 1)
else
return StepRangeLen(r.ref, r.step, 0, 1)
end
else # length(s) == 2
return StepRangeLen(r[2], step(r), 1, 1)
end
else
StepRangeLen(r.ref + (ioffset-r.offset)*r.step, newstep, length(s), max(1,soffset))
soffset = 1 + round(Int, (r.offset - first(s))/step(s))
soffset = clamp(soffset, 1, length(s))
ioffset = first(s) + (soffset-1)*step(s)
if step(s) == 1 || length(s) < 2
newstep = r.step
else
newstep = twiceprecision(r.step*step(s), nbitslen(T, length(s), soffset))
end
if ioffset == r.offset
return StepRangeLen(r.ref, newstep, length(s), max(1,soffset))
else
return StepRangeLen(r.ref + (ioffset-r.offset)*r.step, newstep, length(s), max(1,soffset))
end
end
end

Expand Down
Loading