Skip to content

Commit

Permalink
Make StepRangeLen in operations which may produce zero step (JuliaL…
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored and johanmon committed Jul 5, 2021
1 parent a836fa1 commit dbdf40b
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 20 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Standard library changes
* `count` and `findall` now accept an `AbstractChar` argument to search for a character in a string ([#38675]).
* `range` now supports the `range(start, stop)` and `range(start, stop, length)` methods ([#39228]).
* `range` now supports `start` as an optional keyword argument ([#38041]).
* Some operations on ranges will return a `StepRangeLen` instead of a `StepRange`, to allow the resulting step to be zero. Previously, `λ .* (1:9)` gave an error when `λ = 0`. ([#40320])
* `islowercase` and `isuppercase` are now compliant with the Unicode lower/uppercase categories ([#38574]).
* `iseven` and `isodd` functions now support non-`Integer` numeric types ([#38976]).
* `escape_string` can now receive a collection of characters in the keyword
Expand Down
11 changes: 8 additions & 3 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1153,15 +1153,20 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange, x::Number) = LinRa
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::LinRange) = LinRange(x - r.start, x - r.stop, length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2

broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r))
# at present Base.range_start_step_length(1,0,5) is an error, so for 0 .* (-2:2) we explicitly construct StepRangeLen:
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = StepRangeLen(x*first(r), x*step(r), length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} =
StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len)
# separate in case of noncommutative multiplication
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::AbstractFloat, r::OrdinalRange) =
Base.range_start_step_length(x*first(r), x, length(r)) # 0.2 .* (-2:2) needs TwicePrecision
# separate in case of noncommutative multiplication:
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = StepRangeLen(first(r)*x, step(r)*x, length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} =
StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::OrdinalRange, x::AbstractFloat) =
Base.range_start_step_length(first(r)*x, x, length(r))

broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} =
Expand Down
10 changes: 9 additions & 1 deletion base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,14 @@ end
show(io::IO, r::AbstractRange) = print(io, repr(first(r)), ':', repr(step(r)), ':', repr(last(r)))
show(io::IO, r::UnitRange) = print(io, repr(first(r)), ':', repr(last(r)))
show(io::IO, r::OneTo) = print(io, "Base.OneTo(", r.stop, ")")
function show(io::IO, r::StepRangeLen)
if step(r) != 0
print(io, repr(first(r)), ':', repr(step(r)), ':', repr(last(r)))
else
# ugly temporary printing, to avoid 0:0:0 etc.
print(io, "StepRangeLen(", repr(first(r)), ", ", repr(step(r)), ", ", repr(length(r)), ")")
end
end

function ==(r::T, s::T) where {T<:AbstractRange}
isempty(r) && return isempty(s)
Expand Down Expand Up @@ -1238,7 +1246,7 @@ function _define_range_op(@nospecialize f)
r1l = length(r1)
(r1l == length(r2) ||
throw(DimensionMismatch("argument dimensions must match: length of r1 is $r1l, length of r2 is $(length(r2))")))
range($f(first(r1), first(r2)), step=$f(step(r1), step(r2)), length=r1l)
StepRangeLen($f(first(r1), first(r2)), $f(step(r1), step(r2)), r1l)
end

function $f(r1::LinRange{T}, r2::LinRange{T}) where T
Expand Down
8 changes: 5 additions & 3 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,8 @@ end
end
end

# issue 40309
@test Base.broadcasted_kwsyntax(+, [1], [2]) isa Broadcast.Broadcasted{<:Any, <:Any, typeof(+)}
@test Broadcast.BroadcastFunction(+)(2:3, 2:3) === 4:2:6
@testset "Issue #40309: still gives a range after #40320" begin
@test Base.broadcasted_kwsyntax(+, [1], [2]) isa Broadcast.Broadcasted{<:Any, <:Any, typeof(+)}
@test Broadcast.BroadcastFunction(+)(2:3, 2:3) == 4:2:6
@test Broadcast.BroadcastFunction(+)(2:3, 2:3) isa AbstractRange
end
38 changes: 25 additions & 13 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ end
@test sprint(show, StepRange(1, 2, 5)) == "1:2:5"
end

@testset "Issue 11049 and related" begin
@testset "Issue 11049, and related" begin
@test promote(range(0f0, stop=1f0, length=3), range(0., stop=5., length=2)) ===
(range(0., stop=1., length=3), range(0., stop=5., length=2))
@test convert(LinRange{Float64}, range(0., stop=1., length=3)) === LinRange(0., 1., 3)
Expand Down Expand Up @@ -1144,6 +1144,7 @@ end
@test [reverse(range(1.0, stop=27.0, length=1275));] ==
reverse([range(1.0, stop=27.0, length=1275);])
end

@testset "PR 12200 and related" begin
for _r in (1:2:100, 1:100, 1f0:2f0:100f0, 1.0:2.0:100.0,
range(1, stop=100, length=10), range(1f0, stop=100f0, length=10))
Expand Down Expand Up @@ -1288,8 +1289,8 @@ end
@test_throws BoundsError r[4]
@test_throws BoundsError r[0]
@test broadcast(+, r, 1) === 2:4
@test 2*r === 2:2:6
@test r + r === 2:2:6
@test 2*r == 2:2:6
@test r + r == 2:2:6
k = 0
for i in r
@test i == (k += 1)
Expand Down Expand Up @@ -1432,14 +1433,14 @@ end
@test @inferred(r .+ x) === 3:7
@test @inferred(r .- x) === -1:3
@test @inferred(x .- r) === 1:-1:-3
@test @inferred(x .* r) === 2:2:10
@test @inferred(r .* x) === 2:2:10
@test @inferred(x .* r) == 2:2:10
@test @inferred(r .* x) == 2:2:10
@test @inferred(r ./ x) === 0.5:0.5:2.5
@test @inferred(x ./ r) == 2 ./ [r;] && isa(x ./ r, Vector{Float64})
@test @inferred(r .\ x) == 2 ./ [r;] && isa(x ./ r, Vector{Float64})
@test @inferred(x .\ r) === 0.5:0.5:2.5

@test @inferred(2 .* (r .+ 1) .+ 2) === 6:2:14
@test @inferred(2 .* (r .+ 1) .+ 2) == 6:2:14
end

@testset "Bad range calls" begin
Expand Down Expand Up @@ -1564,17 +1565,22 @@ end # module NonStandardIntegerRangeTest
end

@testset "constant-valued ranges (issues #10391 and #29052)" begin
for r in ((1:4), (1:1:4), (1.0:4.0))
is_int = eltype(r) === Int
@test @inferred(0 * r) == [0.0, 0.0, 0.0, 0.0] broken=is_int
@test @inferred(0 .* r) == [0.0, 0.0, 0.0, 0.0] broken=is_int
@test @inferred(r + (4:-1:1)) == [5.0, 5.0, 5.0, 5.0] broken=is_int
@test @inferred(r .+ (4:-1:1)) == [5.0, 5.0, 5.0, 5.0] broken=is_int
@testset "with $(nameof(typeof(r))) of $(eltype(r))" for r in ((1:4), (1:1:4), StepRangeLen(1,1,4), (1.0:4.0))
@test @inferred(0 * r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(0 .* r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r .* 0) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r + (4:-1:1)) == [5.0, 5.0, 5.0, 5.0]
@test @inferred(r .+ (4:-1:1)) == [5.0, 5.0, 5.0, 5.0]
@test @inferred(r - r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r .- r) == [0.0, 0.0, 0.0, 0.0]

@test @inferred(r .+ (4.0:-1:1)) == [5.0, 5.0, 5.0, 5.0]
@test @inferred(0.0 * r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(0.0 .* r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r / Inf) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r ./ Inf) == [0.0, 0.0, 0.0, 0.0]

@test eval(Meta.parse(repr(0 * r))) == [0.0, 0.0, 0.0, 0.0]
end

@test_broken @inferred(range(0, step=0, length=4)) == [0, 0, 0, 0]
Expand All @@ -1587,7 +1593,7 @@ end
@test @inferred(range(0.0, stop=0, length=4)) == [0.0, 0.0, 0.0, 0.0]

z4 = 0.0 * (1:4)
@test @inferred(z4 .+ (1:4)) === 1.0:1.0:4.0
@test @inferred(z4 .+ (1:4)) == 1.0:1.0:4.0
@test @inferred(z4 .+ z4) === z4
end

Expand Down Expand Up @@ -1889,3 +1895,9 @@ end
@test_throws BoundsError r[true:true:false]
@test_throws BoundsError r[true:true:true]
end

@testset "PR 40320 nanosoldier" begin
@test 0.2 * (-2:2) == -0.4:0.2:0.4 # from tests of AbstractFFTs, needs Base.TwicePrecision
@test 0.2f0 * (-2:2) == Float32.(-0.4:0.2:0.4) # likewise needs Float64
@test 0.2 * (-2:1:2) == -0.4:0.2:0.4
end

0 comments on commit dbdf40b

Please sign in to comment.