Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make StepRangeLen in operations which may produce zero step #40320

Merged
merged 14 commits into from
May 15, 2021
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Standard library changes
* `count` and `findall` now accept an `AbstractChar` argument to search for a character in a string ([#38675]).
* `range` now supports the `range(start, stop)` and `range(start, stop, length)` methods ([#39228]).
* `range` now supports `start` as an optional keyword argument ([#38041]).
* Some operations on ranges will return a `StepRangeLen` instead of a `StepRange`, to allow the resulting step to be zero. Previously, `λ .* (1:9)` gave an error when `λ = 0`. ([#40320])
* `islowercase` and `isuppercase` are now compliant with the Unicode lower/uppercase categories ([#38574]).
* `iseven` and `isodd` functions now support non-`Integer` numeric types ([#38976]).
* `escape_string` can now receive a collection of characters in the keyword
Expand Down
11 changes: 8 additions & 3 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1153,15 +1153,20 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange, x::Number) = LinRa
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::LinRange) = LinRange(x - r.start, x - r.stop, length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2

broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r))
# at present Base.range_start_step_length(1,0,5) is an error, so for 0 .* (-2:2) we explicitly construct StepRangeLen:
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = StepRangeLen(x*first(r), x*step(r), length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} =
StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len)
# separate in case of noncommutative multiplication
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::AbstractFloat, r::OrdinalRange) =
Base.range_start_step_length(x*first(r), x, length(r)) # 0.2 .* (-2:2) needs TwicePrecision
# separate in case of noncommutative multiplication:
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = StepRangeLen(first(r)*x, step(r)*x, length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} =
StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::OrdinalRange, x::AbstractFloat) =
Base.range_start_step_length(first(r)*x, x, length(r))

broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} =
Expand Down
10 changes: 9 additions & 1 deletion base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,14 @@ end
show(io::IO, r::AbstractRange) = print(io, repr(first(r)), ':', repr(step(r)), ':', repr(last(r)))
show(io::IO, r::UnitRange) = print(io, repr(first(r)), ':', repr(last(r)))
show(io::IO, r::OneTo) = print(io, "Base.OneTo(", r.stop, ")")
function show(io::IO, r::StepRangeLen)
if step(r) != 0
print(io, repr(first(r)), ':', repr(step(r)), ':', repr(last(r)))
else
# ugly temporary printing, to avoid 0:0:0 etc.
print(io, "StepRangeLen(", repr(first(r)), ", ", repr(step(r)), ", ", repr(length(r)), ")")
end
end

function ==(r::T, s::T) where {T<:AbstractRange}
isempty(r) && return isempty(s)
Expand Down Expand Up @@ -1238,7 +1246,7 @@ function _define_range_op(@nospecialize f)
r1l = length(r1)
(r1l == length(r2) ||
throw(DimensionMismatch("argument dimensions must match: length of r1 is $r1l, length of r2 is $(length(r2))")))
range($f(first(r1), first(r2)), step=$f(step(r1), step(r2)), length=r1l)
StepRangeLen($f(first(r1), first(r2)), $f(step(r1), step(r2)), r1l)
end

function $f(r1::LinRange{T}, r2::LinRange{T}) where T
Expand Down
8 changes: 5 additions & 3 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,8 @@ end
end
end

# issue 40309
@test Base.broadcasted_kwsyntax(+, [1], [2]) isa Broadcast.Broadcasted{<:Any, <:Any, typeof(+)}
@test Broadcast.BroadcastFunction(+)(2:3, 2:3) === 4:2:6
@testset "Issue #40309: still gives a range after #40320" begin
@test Base.broadcasted_kwsyntax(+, [1], [2]) isa Broadcast.Broadcasted{<:Any, <:Any, typeof(+)}
@test Broadcast.BroadcastFunction(+)(2:3, 2:3) == 4:2:6
@test Broadcast.BroadcastFunction(+)(2:3, 2:3) isa AbstractRange
end
38 changes: 25 additions & 13 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ end
@test sprint(show, StepRange(1, 2, 5)) == "1:2:5"
end

@testset "Issue 11049 and related" begin
@testset "Issue 11049, and related" begin
@test promote(range(0f0, stop=1f0, length=3), range(0., stop=5., length=2)) ===
(range(0., stop=1., length=3), range(0., stop=5., length=2))
@test convert(LinRange{Float64}, range(0., stop=1., length=3)) === LinRange(0., 1., 3)
Expand Down Expand Up @@ -1144,6 +1144,7 @@ end
@test [reverse(range(1.0, stop=27.0, length=1275));] ==
reverse([range(1.0, stop=27.0, length=1275);])
end

@testset "PR 12200 and related" begin
for _r in (1:2:100, 1:100, 1f0:2f0:100f0, 1.0:2.0:100.0,
range(1, stop=100, length=10), range(1f0, stop=100f0, length=10))
Expand Down Expand Up @@ -1288,8 +1289,8 @@ end
@test_throws BoundsError r[4]
@test_throws BoundsError r[0]
@test broadcast(+, r, 1) === 2:4
@test 2*r === 2:2:6
@test r + r === 2:2:6
@test 2*r == 2:2:6
@test r + r == 2:2:6
k = 0
for i in r
@test i == (k += 1)
Expand Down Expand Up @@ -1432,14 +1433,14 @@ end
@test @inferred(r .+ x) === 3:7
@test @inferred(r .- x) === -1:3
@test @inferred(x .- r) === 1:-1:-3
@test @inferred(x .* r) === 2:2:10
@test @inferred(r .* x) === 2:2:10
@test @inferred(x .* r) == 2:2:10
@test @inferred(r .* x) == 2:2:10
@test @inferred(r ./ x) === 0.5:0.5:2.5
@test @inferred(x ./ r) == 2 ./ [r;] && isa(x ./ r, Vector{Float64})
@test @inferred(r .\ x) == 2 ./ [r;] && isa(x ./ r, Vector{Float64})
@test @inferred(x .\ r) === 0.5:0.5:2.5

@test @inferred(2 .* (r .+ 1) .+ 2) === 6:2:14
@test @inferred(2 .* (r .+ 1) .+ 2) == 6:2:14
end

@testset "Bad range calls" begin
Expand Down Expand Up @@ -1564,17 +1565,22 @@ end # module NonStandardIntegerRangeTest
end

@testset "constant-valued ranges (issues #10391 and #29052)" begin
for r in ((1:4), (1:1:4), (1.0:4.0))
is_int = eltype(r) === Int
@test @inferred(0 * r) == [0.0, 0.0, 0.0, 0.0] broken=is_int
@test @inferred(0 .* r) == [0.0, 0.0, 0.0, 0.0] broken=is_int
@test @inferred(r + (4:-1:1)) == [5.0, 5.0, 5.0, 5.0] broken=is_int
@test @inferred(r .+ (4:-1:1)) == [5.0, 5.0, 5.0, 5.0] broken=is_int
@testset "with $(nameof(typeof(r))) of $(eltype(r))" for r in ((1:4), (1:1:4), StepRangeLen(1,1,4), (1.0:4.0))
@test @inferred(0 * r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(0 .* r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r .* 0) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r + (4:-1:1)) == [5.0, 5.0, 5.0, 5.0]
@test @inferred(r .+ (4:-1:1)) == [5.0, 5.0, 5.0, 5.0]
@test @inferred(r - r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r .- r) == [0.0, 0.0, 0.0, 0.0]

@test @inferred(r .+ (4.0:-1:1)) == [5.0, 5.0, 5.0, 5.0]
@test @inferred(0.0 * r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(0.0 .* r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r / Inf) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r ./ Inf) == [0.0, 0.0, 0.0, 0.0]

@test eval(Meta.parse(repr(0 * r))) == [0.0, 0.0, 0.0, 0.0]
end

@test_broken @inferred(range(0, step=0, length=4)) == [0, 0, 0, 0]
Expand All @@ -1587,7 +1593,7 @@ end
@test @inferred(range(0.0, stop=0, length=4)) == [0.0, 0.0, 0.0, 0.0]

z4 = 0.0 * (1:4)
@test @inferred(z4 .+ (1:4)) === 1.0:1.0:4.0
@test @inferred(z4 .+ (1:4)) == 1.0:1.0:4.0
@test @inferred(z4 .+ z4) === z4
end

Expand Down Expand Up @@ -1889,3 +1895,9 @@ end
@test_throws BoundsError r[true:true:false]
@test_throws BoundsError r[true:true:true]
end

@testset "PR 40320 nanosoldier" begin
@test 0.2 * (-2:2) == -0.4:0.2:0.4 # from tests of AbstractFFTs, needs Base.TwicePrecision
@test 0.2f0 * (-2:2) == Float32.(-0.4:0.2:0.4) # likewise needs Float64
@test 0.2 * (-2:1:2) == -0.4:0.2:0.4
end