From 85fac87f6304b92bcd8efbd75c32dfe214ee2ddf Mon Sep 17 00:00:00 2001 From: Erik Schnetter Date: Tue, 13 Sep 2022 04:05:11 -0400 Subject: [PATCH] LinearAlgebra: Allow arrays to be zero-preserving (#46340) --- .../LinearAlgebra/src/structuredbroadcast.jl | 1 + stdlib/LinearAlgebra/test/diagonal.jl | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/stdlib/LinearAlgebra/src/structuredbroadcast.jl b/stdlib/LinearAlgebra/src/structuredbroadcast.jl index 95a1842702291..ccf95f88a1bee 100644 --- a/stdlib/LinearAlgebra/src/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/src/structuredbroadcast.jl @@ -126,6 +126,7 @@ fails as `zero(::Tuple{Int})` is not defined. However, """ iszerodefined(::Type) = false iszerodefined(::Type{<:Number}) = true +iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T) fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0)) # Like sparse matrices, we assume that the zero-preservation property of a broadcasted diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 4c54eb6a11003..3e6f456c3de1e 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1106,4 +1106,27 @@ end @test outTri === mul!(outTri, UTriA, D, 2, 1)::Tri == mul!(out, Matrix(UTriA), D, 2, 1) end +struct SMatrix1{T} <: AbstractArray{T,2} + elt::T +end +Base.:(==)(A::SMatrix1, B::SMatrix1) = A.elt == B.elt +Base.zero(::Type{SMatrix1{T}}) where {T} = SMatrix1(zero(T)) +Base.iszero(A::SMatrix1) = iszero(A.elt) +Base.getindex(A::SMatrix1, inds...) = A.elt +Base.size(::SMatrix1) = (1, 1) +@testset "map for Diagonal matrices (#46292)" begin + A = Diagonal([1]) + @test A isa Diagonal{Int,Vector{Int}} + @test 2*A isa Diagonal{Int,Vector{Int}} + @test A.+1 isa Matrix{Int} + # Numeric element types remain diagonal + B = map(SMatrix1, A) + @test B == fill(SMatrix1(1), 1, 1) + @test B isa Diagonal{SMatrix1{Int},Vector{SMatrix1{Int}}} + # Non-numeric element types become dense + C = map(a -> SMatrix1(string(a)), A) + @test C == fill(SMatrix1(string(1)), 1, 1) + @test C isa Matrix{SMatrix1{String}} +end + end # module TestDiagonal