From 507ed3fb5f1ec40a1a25864f7b1b229cfd6ffa09 Mon Sep 17 00:00:00 2001 From: Sacha Verweij Date: Sun, 29 Oct 2017 11:59:04 -0700 Subject: [PATCH] Bind I to UniformScaling(true) rather than UniformScaling(1) for better promotion behavior. --- NEWS.md | 3 +++ base/linalg/uniformscaling.jl | 20 ++++++++++---------- test/core.jl | 2 +- test/linalg/uniformscaling.jl | 13 ++++++++++++- 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/NEWS.md b/NEWS.md index cdf940755907f..d2e14b873ed5c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -255,6 +255,9 @@ This section lists changes that do not have deprecation warnings. * All command line arguments passed via `-e`, `-E`, and `-L` will be executed in the order given on the command line ([#23665]). + * `I` now yields `UniformScaling{Bool}(true)` rather than `UniformScaling{Int64}(1)` + to better preserve types in operations involving `I` ([#24396]). + * The return type of `reinterpret` has changed to `ReinterpretArray`. `reinterpret` on sparse arrays has been discontinued. diff --git a/base/linalg/uniformscaling.jl b/base/linalg/uniformscaling.jl index dfca2146330ec..2fdbdd9c31354 100644 --- a/base/linalg/uniformscaling.jl +++ b/base/linalg/uniformscaling.jl @@ -47,7 +47,7 @@ julia> [1 2im 3; 1im 2 3] * I 0+1im 2+0im 3+0im ``` """ -const I = UniformScaling(1) +const I = UniformScaling(true) eltype(::Type{UniformScaling{T}}) where {T} = T ndims(J::UniformScaling) = 2 @@ -99,7 +99,7 @@ for (t1, t2) in ((:UnitUpperTriangular, :UpperTriangular), ($op)(UL::$t2, J::UniformScaling) = ($t2)(($op)(UL.data, J)) function ($op)(UL::$t1, J::UniformScaling) - ULnew = copy_oftype(UL.data, promote_type(eltype(UL), eltype(J))) + ULnew = copy_oftype(UL.data, Base.Broadcast._broadcast_eltype($op, UL, J)) for i = 1:size(ULnew, 1) ULnew[i,i] = ($op)(1, J.λ) end @@ -110,7 +110,7 @@ for (t1, t2) in ((:UnitUpperTriangular, :UpperTriangular), end function (-)(J::UniformScaling, UL::Union{UpperTriangular,UnitUpperTriangular}) - ULnew = similar(parent(UL), promote_type(eltype(J), eltype(UL))) + ULnew = similar(parent(UL), Base.Broadcast._broadcast_eltype(-, J, UL)) n = size(ULnew, 1) ULold = UL.data for j = 1:n @@ -126,7 +126,7 @@ function (-)(J::UniformScaling, UL::Union{UpperTriangular,UnitUpperTriangular}) return UpperTriangular(ULnew) end function (-)(J::UniformScaling, UL::Union{LowerTriangular,UnitLowerTriangular}) - ULnew = similar(parent(UL), promote_type(eltype(J), eltype(UL))) + ULnew = similar(parent(UL), Base.Broadcast._broadcast_eltype(-, J, UL)) n = size(ULnew, 1) ULold = UL.data for j = 1:n @@ -142,9 +142,9 @@ function (-)(J::UniformScaling, UL::Union{LowerTriangular,UnitLowerTriangular}) return LowerTriangular(ULnew) end -function (+)(A::AbstractMatrix{TA}, J::UniformScaling{TJ}) where {TA,TJ} +function (+)(A::AbstractMatrix, J::UniformScaling) n = checksquare(A) - B = similar(A, promote_type(TA,TJ)) + B = similar(A, Base.Broadcast._broadcast_eltype(+, A, J)) copy!(B,A) @inbounds for i = 1:n B[i,i] += J.λ @@ -152,18 +152,18 @@ function (+)(A::AbstractMatrix{TA}, J::UniformScaling{TJ}) where {TA,TJ} B end -function (-)(A::AbstractMatrix{TA}, J::UniformScaling{TJ}) where {TA,TJ<:Number} +function (-)(A::AbstractMatrix, J::UniformScaling) n = checksquare(A) - B = similar(A, promote_type(TA,TJ)) + B = similar(A, Base.Broadcast._broadcast_eltype(-, A, J)) copy!(B, A) @inbounds for i = 1:n B[i,i] -= J.λ end B end -function (-)(J::UniformScaling{TJ}, A::AbstractMatrix{TA}) where {TA,TJ<:Number} +function (-)(J::UniformScaling, A::AbstractMatrix) n = checksquare(A) - B = convert(AbstractMatrix{promote_type(TJ,TA)}, -A) + B = convert(AbstractMatrix{Base.Broadcast._broadcast_eltype(-, J, A)}, -A) @inbounds for j = 1:n B[j,j] += J.λ end diff --git a/test/core.jl b/test/core.jl index 71273904203a1..23be56778df78 100644 --- a/test/core.jl +++ b/test/core.jl @@ -1919,7 +1919,7 @@ test5536(a::Union{Real, AbstractArray}) = "Non-splatting" # issue #6142 import Base: + mutable struct A6142 <: AbstractMatrix{Float64}; end -+(x::A6142, y::UniformScaling{TJ}) where {TJ} = "UniformScaling method called" ++(x::A6142, y::UniformScaling) = "UniformScaling method called" +(x::A6142, y::AbstractArray) = "AbstractArray method called" @test A6142() + I == "UniformScaling method called" +(x::A6142, y::AbstractRange) = "AbstractRange method called" #16324 ambiguity diff --git a/test/linalg/uniformscaling.jl b/test/linalg/uniformscaling.jl index faebd6a282912..ba55e648d972c 100644 --- a/test/linalg/uniformscaling.jl +++ b/test/linalg/uniformscaling.jl @@ -51,7 +51,7 @@ end end @testset "det and logdet" begin - @test det(I) === 1 + @test det(I) === true @test det(1.0I) === 1.0 @test det(0I) === 0 @test det(0.0I) === 0.0 @@ -216,3 +216,14 @@ end @test alltwos != 2I != alltwos # test generic path / inequality off diag @test rdenseI != I != rdenseI # test square matrix check end + +@testset "operations involving I should preserve eltype" begin + @test isa(Int8(1) + I, Int8) + @test isa(Float16(1) + I, Float16) + @test eltype(Int8(1)I) == Int8 + @test eltype(Float16(1)I) == Float16 + @test eltype(fill(Int8(1), 2, 2)I) == Int8 + @test eltype(fill(Float16(1), 2, 2)I) == Float16 + @test eltype(fill(Int8(1), 2, 2) + I) == Int8 + @test eltype(fill(Float16(1), 2, 2) + I) == Float16 +end