From dbdf40bb86b32daba22ebda5d0a6c48d8bb9825f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 15 May 2021 12:39:39 -0400 Subject: [PATCH] Make `StepRangeLen` in operations which may produce zero step (#40320) --- NEWS.md | 1 + base/broadcast.jl | 11 ++++++++--- base/range.jl | 10 +++++++++- test/broadcast.jl | 8 +++++--- test/ranges.jl | 38 +++++++++++++++++++++++++------------- 5 files changed, 48 insertions(+), 20 deletions(-) diff --git a/NEWS.md b/NEWS.md index 7bacc44220b9ea..ea05d66aac3a6e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -60,6 +60,7 @@ Standard library changes * `count` and `findall` now accept an `AbstractChar` argument to search for a character in a string ([#38675]). * `range` now supports the `range(start, stop)` and `range(start, stop, length)` methods ([#39228]). * `range` now supports `start` as an optional keyword argument ([#38041]). +* Some operations on ranges will return a `StepRangeLen` instead of a `StepRange`, to allow the resulting step to be zero. Previously, `λ .* (1:9)` gave an error when `λ = 0`. ([#40320]) * `islowercase` and `isuppercase` are now compliant with the Unicode lower/uppercase categories ([#38574]). * `iseven` and `isodd` functions now support non-`Integer` numeric types ([#38976]). * `escape_string` can now receive a collection of characters in the keyword diff --git a/base/broadcast.jl b/base/broadcast.jl index 4c3f91c50638fc..47bd01145cd7d7 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -1153,15 +1153,20 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange, x::Number) = LinRa broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::LinRange) = LinRange(x - r.start, x - r.stop, length(r)) broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2 -broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r)) +# at present Base.range_start_step_length(1,0,5) is an error, so for 0 .* (-2:2) we explicitly construct StepRangeLen: +broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = StepRangeLen(x*first(r), x*step(r), length(r)) broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} = StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset) broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len) -# separate in case of noncommutative multiplication -broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r)) +broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::AbstractFloat, r::OrdinalRange) = + Base.range_start_step_length(x*first(r), x, length(r)) # 0.2 .* (-2:2) needs TwicePrecision +# separate in case of noncommutative multiplication: +broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = StepRangeLen(first(r)*x, step(r)*x, length(r)) broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} = StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset) broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len) +broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::OrdinalRange, x::AbstractFloat) = + Base.range_start_step_length(first(r)*x, x, length(r)) broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r)) broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} = diff --git a/base/range.jl b/base/range.jl index 41120e1bcdaf0d..ca1f577ecf00aa 100644 --- a/base/range.jl +++ b/base/range.jl @@ -927,6 +927,14 @@ end show(io::IO, r::AbstractRange) = print(io, repr(first(r)), ':', repr(step(r)), ':', repr(last(r))) show(io::IO, r::UnitRange) = print(io, repr(first(r)), ':', repr(last(r))) show(io::IO, r::OneTo) = print(io, "Base.OneTo(", r.stop, ")") +function show(io::IO, r::StepRangeLen) + if step(r) != 0 + print(io, repr(first(r)), ':', repr(step(r)), ':', repr(last(r))) + else + # ugly temporary printing, to avoid 0:0:0 etc. + print(io, "StepRangeLen(", repr(first(r)), ", ", repr(step(r)), ", ", repr(length(r)), ")") + end +end function ==(r::T, s::T) where {T<:AbstractRange} isempty(r) && return isempty(s) @@ -1238,7 +1246,7 @@ function _define_range_op(@nospecialize f) r1l = length(r1) (r1l == length(r2) || throw(DimensionMismatch("argument dimensions must match: length of r1 is $r1l, length of r2 is $(length(r2))"))) - range($f(first(r1), first(r2)), step=$f(step(r1), step(r2)), length=r1l) + StepRangeLen($f(first(r1), first(r2)), $f(step(r1), step(r2)), r1l) end function $f(r1::LinRange{T}, r2::LinRange{T}) where T diff --git a/test/broadcast.jl b/test/broadcast.jl index 4c2b92a39bb9f4..57a65110a413a2 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -1045,6 +1045,8 @@ end end end -# issue 40309 -@test Base.broadcasted_kwsyntax(+, [1], [2]) isa Broadcast.Broadcasted{<:Any, <:Any, typeof(+)} -@test Broadcast.BroadcastFunction(+)(2:3, 2:3) === 4:2:6 +@testset "Issue #40309: still gives a range after #40320" begin + @test Base.broadcasted_kwsyntax(+, [1], [2]) isa Broadcast.Broadcasted{<:Any, <:Any, typeof(+)} + @test Broadcast.BroadcastFunction(+)(2:3, 2:3) == 4:2:6 + @test Broadcast.BroadcastFunction(+)(2:3, 2:3) isa AbstractRange +end diff --git a/test/ranges.jl b/test/ranges.jl index b0693a32527d24..dea5d9b6753056 100644 --- a/test/ranges.jl +++ b/test/ranges.jl @@ -1082,7 +1082,7 @@ end @test sprint(show, StepRange(1, 2, 5)) == "1:2:5" end -@testset "Issue 11049 and related" begin +@testset "Issue 11049, and related" begin @test promote(range(0f0, stop=1f0, length=3), range(0., stop=5., length=2)) === (range(0., stop=1., length=3), range(0., stop=5., length=2)) @test convert(LinRange{Float64}, range(0., stop=1., length=3)) === LinRange(0., 1., 3) @@ -1144,6 +1144,7 @@ end @test [reverse(range(1.0, stop=27.0, length=1275));] == reverse([range(1.0, stop=27.0, length=1275);]) end + @testset "PR 12200 and related" begin for _r in (1:2:100, 1:100, 1f0:2f0:100f0, 1.0:2.0:100.0, range(1, stop=100, length=10), range(1f0, stop=100f0, length=10)) @@ -1288,8 +1289,8 @@ end @test_throws BoundsError r[4] @test_throws BoundsError r[0] @test broadcast(+, r, 1) === 2:4 - @test 2*r === 2:2:6 - @test r + r === 2:2:6 + @test 2*r == 2:2:6 + @test r + r == 2:2:6 k = 0 for i in r @test i == (k += 1) @@ -1432,14 +1433,14 @@ end @test @inferred(r .+ x) === 3:7 @test @inferred(r .- x) === -1:3 @test @inferred(x .- r) === 1:-1:-3 - @test @inferred(x .* r) === 2:2:10 - @test @inferred(r .* x) === 2:2:10 + @test @inferred(x .* r) == 2:2:10 + @test @inferred(r .* x) == 2:2:10 @test @inferred(r ./ x) === 0.5:0.5:2.5 @test @inferred(x ./ r) == 2 ./ [r;] && isa(x ./ r, Vector{Float64}) @test @inferred(r .\ x) == 2 ./ [r;] && isa(x ./ r, Vector{Float64}) @test @inferred(x .\ r) === 0.5:0.5:2.5 - @test @inferred(2 .* (r .+ 1) .+ 2) === 6:2:14 + @test @inferred(2 .* (r .+ 1) .+ 2) == 6:2:14 end @testset "Bad range calls" begin @@ -1564,17 +1565,22 @@ end # module NonStandardIntegerRangeTest end @testset "constant-valued ranges (issues #10391 and #29052)" begin - for r in ((1:4), (1:1:4), (1.0:4.0)) - is_int = eltype(r) === Int - @test @inferred(0 * r) == [0.0, 0.0, 0.0, 0.0] broken=is_int - @test @inferred(0 .* r) == [0.0, 0.0, 0.0, 0.0] broken=is_int - @test @inferred(r + (4:-1:1)) == [5.0, 5.0, 5.0, 5.0] broken=is_int - @test @inferred(r .+ (4:-1:1)) == [5.0, 5.0, 5.0, 5.0] broken=is_int + @testset "with $(nameof(typeof(r))) of $(eltype(r))" for r in ((1:4), (1:1:4), StepRangeLen(1,1,4), (1.0:4.0)) + @test @inferred(0 * r) == [0.0, 0.0, 0.0, 0.0] + @test @inferred(0 .* r) == [0.0, 0.0, 0.0, 0.0] + @test @inferred(r .* 0) == [0.0, 0.0, 0.0, 0.0] + @test @inferred(r + (4:-1:1)) == [5.0, 5.0, 5.0, 5.0] + @test @inferred(r .+ (4:-1:1)) == [5.0, 5.0, 5.0, 5.0] + @test @inferred(r - r) == [0.0, 0.0, 0.0, 0.0] + @test @inferred(r .- r) == [0.0, 0.0, 0.0, 0.0] + @test @inferred(r .+ (4.0:-1:1)) == [5.0, 5.0, 5.0, 5.0] @test @inferred(0.0 * r) == [0.0, 0.0, 0.0, 0.0] @test @inferred(0.0 .* r) == [0.0, 0.0, 0.0, 0.0] @test @inferred(r / Inf) == [0.0, 0.0, 0.0, 0.0] @test @inferred(r ./ Inf) == [0.0, 0.0, 0.0, 0.0] + + @test eval(Meta.parse(repr(0 * r))) == [0.0, 0.0, 0.0, 0.0] end @test_broken @inferred(range(0, step=0, length=4)) == [0, 0, 0, 0] @@ -1587,7 +1593,7 @@ end @test @inferred(range(0.0, stop=0, length=4)) == [0.0, 0.0, 0.0, 0.0] z4 = 0.0 * (1:4) - @test @inferred(z4 .+ (1:4)) === 1.0:1.0:4.0 + @test @inferred(z4 .+ (1:4)) == 1.0:1.0:4.0 @test @inferred(z4 .+ z4) === z4 end @@ -1889,3 +1895,9 @@ end @test_throws BoundsError r[true:true:false] @test_throws BoundsError r[true:true:true] end + +@testset "PR 40320 nanosoldier" begin + @test 0.2 * (-2:2) == -0.4:0.2:0.4 # from tests of AbstractFFTs, needs Base.TwicePrecision + @test 0.2f0 * (-2:2) == Float32.(-0.4:0.2:0.4) # likewise needs Float64 + @test 0.2 * (-2:1:2) == -0.4:0.2:0.4 +end