From f0636b2b4796ff44267e94514aa1f434b6aab77b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 27 Feb 2020 08:55:27 -0500 Subject: [PATCH] More tests for Triangular, div and mult (#31831) * More tests for Triangular, div and mult * Long form function defs * avoid useless unwrapping-wrapping in triangular multiplication Co-authored-by: Daniel Karrasch --- stdlib/LinearAlgebra/src/triangular.jl | 124 ++++++++++++++---------- stdlib/LinearAlgebra/test/triangular.jl | 41 ++++++-- 2 files changed, 107 insertions(+), 58 deletions(-) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 2b6698ecf3840..bdf4199753891 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -688,18 +688,24 @@ mul!(C::AbstractVecOrMat, A::AbstractTriangular, adjB::Adjoint{<:Any,<:AbstractV mul!(C::AbstractVector , A::AbstractTriangular, B::AbstractVector) = lmul!(A, copyto!(C, B)) mul!(C::AbstractMatrix , A::AbstractTriangular, B::AbstractVecOrMat) = lmul!(A, copyto!(C, B)) mul!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = lmul!(A, copyto!(C, B)) -mul!(C::AbstractVector , adjA::Adjoint{<:Any,<:AbstractTriangular}, B::AbstractVector) = - (A = adjA.parent; lmul!(adjoint(A), copyto!(C, B))) -mul!(C::AbstractMatrix , adjA::Adjoint{<:Any,<:AbstractTriangular}, B::AbstractVecOrMat) = - (A = adjA.parent; lmul!(adjoint(A), copyto!(C, B))) -mul!(C::AbstractVecOrMat, adjA::Adjoint{<:Any,<:AbstractTriangular}, B::AbstractVecOrMat) = - (A = adjA.parent; lmul!(adjoint(A), copyto!(C, B))) -mul!(C::AbstractVector , transA::Transpose{<:Any,<:AbstractTriangular}, B::AbstractVector) = - (A = transA.parent; lmul!(transpose(A), copyto!(C, B))) -mul!(C::AbstractMatrix , transA::Transpose{<:Any,<:AbstractTriangular}, B::AbstractVecOrMat) = - (A = transA.parent; lmul!(transpose(A), copyto!(C, B))) -mul!(C::AbstractVecOrMat, transA::Transpose{<:Any,<:AbstractTriangular}, B::AbstractVecOrMat) = - (A = transA.parent; lmul!(transpose(A), copyto!(C, B))) +function mul!(C::AbstractVector, adjA::Adjoint{<:Any,<:AbstractTriangular}, B::AbstractVector) + return lmul!(adjA, copyto!(C, B)) +end +function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractTriangular}, B::AbstractVecOrMat) + return lmul!(adjA, copyto!(C, B)) +end +function mul!(C::AbstractVecOrMat, adjA::Adjoint{<:Any,<:AbstractTriangular}, B::AbstractVecOrMat) + return lmul!(adjA, copyto!(C, B)) +end +function mul!(C::AbstractVector, transA::Transpose{<:Any,<:AbstractTriangular}, B::AbstractVector) + return lmul!(transA, copyto!(C, B)) +end +function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractTriangular}, B::AbstractVecOrMat) + return lmul!(transA, copyto!(C, B)) +end +function mul!(C::AbstractVecOrMat, transA::Transpose{<:Any,<:AbstractTriangular}, B::AbstractVecOrMat) + return lmul!(transA, copyto!(C, B)) +end @inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractTriangular}, B::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = mul!(C, A, copy(B), alpha, beta) @inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractTriangular}, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = @@ -1718,44 +1724,62 @@ function rdiv!(A::StridedMatrix, transB::Transpose{<:Any,<:UnitLowerTriangular}) A end -lmul!(adjA::Adjoint{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}, B::UpperTriangular) = - (A = adjA.parent; UpperTriangular(lmul!(adjoint(A), triu!(B.data)))) -lmul!(adjA::Adjoint{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}, B::LowerTriangular) = - (A = adjA.parent; LowerTriangular(lmul!(adjoint(A), tril!(B.data)))) -lmul!(transA::Transpose{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}, B::UpperTriangular) = - (A = transA.parent; UpperTriangular(lmul!(transpose(A), triu!(B.data)))) -lmul!(transA::Transpose{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}, B::LowerTriangular) = - (A = transA.parent; LowerTriangular(lmul!(transpose(A), tril!(B.data)))) -ldiv!(adjA::Adjoint{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}, B::UpperTriangular) = - (A = adjA.parent; UpperTriangular(ldiv!(adjoint(A), triu!(B.data)))) -ldiv!(adjA::Adjoint{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}, B::LowerTriangular) = - (A = adjA.parent; LowerTriangular(ldiv!(adjoint(A), tril!(B.data)))) -ldiv!(transA::Transpose{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}, B::UpperTriangular) = - (A = transA.parent; UpperTriangular(ldiv!(transpose(A), triu!(B.data)))) -ldiv!(transA::Transpose{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}, B::LowerTriangular) = - (A = transA.parent; LowerTriangular(ldiv!(transpose(A), tril!(B.data)))) - -rdiv!(A::UpperTriangular, B::Union{UpperTriangular,UnitUpperTriangular}) = - UpperTriangular(rdiv!(triu!(A.data), B)) -rdiv!(A::LowerTriangular, B::Union{LowerTriangular,UnitLowerTriangular}) = - LowerTriangular(rdiv!(tril!(A.data), B)) - -rmul!(A::UpperTriangular, adjB::Adjoint{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}) = - (B = adjB.parent; UpperTriangular(rmul!(triu!(A.data), adjoint(B)))) -rmul!(A::LowerTriangular, adjB::Adjoint{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}) = - (B = adjB.parent; LowerTriangular(rmul!(tril!(A.data), adjoint(B)))) -rmul!(A::UpperTriangular, transB::Transpose{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}) = - (B = transB.parent; UpperTriangular(rmul!(triu!(A.data), transpose(B)))) -rmul!(A::LowerTriangular, transB::Transpose{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}) = - (B = transB.parent; LowerTriangular(rmul!(tril!(A.data), transpose(B)))) -rdiv!(A::UpperTriangular, adjB::Adjoint{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}) = - (B = adjB.parent; UpperTriangular(rdiv!(triu!(A.data), adjoint(B)))) -rdiv!(A::LowerTriangular, adjB::Adjoint{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}) = - (B = adjB.parent; LowerTriangular(rdiv!(tril!(A.data), adjoint(B)))) -rdiv!(A::UpperTriangular, transB::Transpose{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}) = - (B = transB.parent; UpperTriangular(rdiv!(triu!(A.data), transpose(B)))) -rdiv!(A::LowerTriangular, transB::Transpose{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}) = - (B = transB.parent; LowerTriangular(rdiv!(tril!(A.data), transpose(B)))) +function lmul!(adjA::Adjoint{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}, B::UpperTriangular) + return UpperTriangular(lmul!(adjA, triu!(B.data))) +end +function lmul!(adjA::Adjoint{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}, B::LowerTriangular) + return LowerTriangular(lmul!(adjA, tril!(B.data))) +end +function lmul!(transA::Transpose{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}, B::UpperTriangular) + return UpperTriangular(lmul!(transA, triu!(B.data))) +end +function lmul!(transA::Transpose{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}, B::LowerTriangular) + return LowerTriangular(lmul!(transA, tril!(B.data))) +end +function ldiv!(adjA::Adjoint{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}, B::UpperTriangular) + return UpperTriangular(ldiv!(adjA, triu!(B.data))) +end +function ldiv!(adjA::Adjoint{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}, B::LowerTriangular) + return LowerTriangular(ldiv!(adjA, tril!(B.data))) +end +function ldiv!(transA::Transpose{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}, B::UpperTriangular) + return UpperTriangular(ldiv!(transA, triu!(B.data))) +end +function ldiv!(transA::Transpose{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}, B::LowerTriangular) + return LowerTriangular(ldiv!(transA, tril!(B.data))) +end + +function rdiv!(A::UpperTriangular, B::Union{UpperTriangular,UnitUpperTriangular}) + return UpperTriangular(rdiv!(triu!(A.data), B)) +end +function rdiv!(A::LowerTriangular, B::Union{LowerTriangular,UnitLowerTriangular}) + return LowerTriangular(rdiv!(tril!(A.data), B)) +end + +function rmul!(A::UpperTriangular, adjB::Adjoint{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}) + return UpperTriangular(rmul!(triu!(A.data), adjB)) +end +function rmul!(A::LowerTriangular, adjB::Adjoint{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}) + return LowerTriangular(rmul!(tril!(A.data), adjB)) +end +function rmul!(A::UpperTriangular, transB::Transpose{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}) + return UpperTriangular(rmul!(triu!(A.data), transB)) +end +function rmul!(A::LowerTriangular, transB::Transpose{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}) + return LowerTriangular(rmul!(tril!(A.data), transB)) +end +function rdiv!(A::UpperTriangular, adjB::Adjoint{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}) + return UpperTriangular(rdiv!(triu!(A.data), adjB)) +end +function rdiv!(A::LowerTriangular, adjB::Adjoint{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}) + return LowerTriangular(rdiv!(tril!(A.data), adjB)) +end +function rdiv!(A::UpperTriangular, transB::Transpose{<:Any,<:Union{LowerTriangular,UnitLowerTriangular}}) + return UpperTriangular(rdiv!(triu!(A.data), transB)) +end +function rdiv!(A::LowerTriangular, transB::Transpose{<:Any,<:Union{UpperTriangular,UnitUpperTriangular}}) + return LowerTriangular(rdiv!(tril!(A.data), transB)) +end # Promotion ## Promotion methods in matmul don't apply to triangular multiplication since diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index 5f6ec8771a030..5e8840df081e7 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -323,6 +323,8 @@ for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFlo # Triangular-Triangualar multiplication and division @test A1*A2 ≈ Matrix(A1)*Matrix(A2) @test transpose(A1)*A2 ≈ transpose(Matrix(A1))*Matrix(A2) + @test transpose(A1)*adjoint(A2) ≈ transpose(Matrix(A1))*adjoint(Matrix(A2)) + @test adjoint(A1)*transpose(A2) ≈ adjoint(Matrix(A1))*transpose(Matrix(A2)) @test A1'A2 ≈ Matrix(A1)'Matrix(A2) @test A1*transpose(A2) ≈ Matrix(A1)*transpose(Matrix(A2)) @test A1*A2' ≈ Matrix(A1)*Matrix(A2)' @@ -340,6 +342,21 @@ for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFlo @test_throws DimensionMismatch transpose(A2) * offsizeA @test_throws DimensionMismatch A2' * offsizeA @test_throws DimensionMismatch A2 * offsizeA + if (uplo1 == uplo2 && elty1 == elty2 != Int && t1 != UnitLowerTriangular && t1 != UnitUpperTriangular) + @test rdiv!(copy(A1), copy(A2)) ≈ A1/A2 ≈ Matrix(A1)/Matrix(A2) + end + if (uplo1 != uplo2 && elty1 == elty2 != Int && t2 != UnitLowerTriangular && t2 != UnitUpperTriangular) + @test lmul!(adjoint(copy(A1)), copy(A2)) ≈ A1'*A2 ≈ Matrix(A1)'*Matrix(A2) + @test lmul!(transpose(copy(A1)), copy(A2)) ≈ transpose(A1)*A2 ≈ transpose(Matrix(A1))*Matrix(A2) + @test ldiv!(adjoint(copy(A1)), copy(A2)) ≈ A1'\A2 ≈ Matrix(A1)'\Matrix(A2) + @test ldiv!(transpose(copy(A1)), copy(A2)) ≈ transpose(A1)\A2 ≈ transpose(Matrix(A1))\Matrix(A2) + end + if (uplo1 != uplo2 && elty1 == elty2 != Int && t1 != UnitLowerTriangular && t1 != UnitUpperTriangular) + @test rmul!(copy(A1), adjoint(copy(A2))) ≈ A1*A2' ≈ Matrix(A1)*Matrix(A2)' + @test rmul!(copy(A1), transpose(copy(A2))) ≈ A1*transpose(A2) ≈ Matrix(A1)*transpose(Matrix(A2)) + @test rdiv!(copy(A1), adjoint(copy(A2))) ≈ A1/A2' ≈ Matrix(A1)/Matrix(A2)' + @test rdiv!(copy(A1), transpose(copy(A2))) ≈ A1/transpose(A2) ≈ Matrix(A1)/transpose(Matrix(A2)) + end end end @@ -368,11 +385,15 @@ for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFlo @test transpose(A1)*B ≈ transpose(Matrix(A1))*B @test A1'B ≈ Matrix(A1)'B @test A1*transpose(B) ≈ Matrix(A1)*transpose(B) + @test adjoint(A1)*transpose(B) ≈ Matrix(A1)'*transpose(B) + @test transpose(A1)*adjoint(B) ≈ transpose(Matrix(A1))*adjoint(B) @test A1*B' ≈ Matrix(A1)*B' @test B*A1 ≈ B*Matrix(A1) @test transpose(B[:,1])*A1 ≈ transpose(B[:,1])*Matrix(A1) @test B[:,1]'A1 ≈ B[:,1]'Matrix(A1) @test transpose(B)*A1 ≈ transpose(B)*Matrix(A1) + @test transpose(B)*adjoint(A1) ≈ transpose(B)*Matrix(A1)' + @test adjoint(B)*transpose(A1) ≈ adjoint(B)*transpose(Matrix(A1)) @test B'A1 ≈ B'Matrix(A1) @test B*transpose(A1) ≈ B*transpose(Matrix(A1)) @test B*A1' ≈ B*Matrix(A1)' @@ -382,16 +403,20 @@ for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFlo @test B'A1' ≈ B'Matrix(A1)' if eltyB == elty1 - @test mul!(similar(B),A1,B) ≈ A1*B - @test mul!(similar(B), A1, adjoint(B)) ≈ A1*B' - @test mul!(similar(B), A1, transpose(B)) ≈ A1*transpose(B) - @test mul!(similar(B), adjoint(A1), B) ≈ A1'*B - @test mul!(similar(B), transpose(A1), B) ≈ transpose(A1)*B + @test mul!(similar(B), A1, B) ≈ Matrix(A1)*B + @test mul!(similar(B), A1, adjoint(B)) ≈ Matrix(A1)*B' + @test mul!(similar(B), A1, transpose(B)) ≈ Matrix(A1)*transpose(B) + @test mul!(similar(B), adjoint(A1), adjoint(B)) ≈ Matrix(A1)'*B' + @test mul!(similar(B), transpose(A1), transpose(B)) ≈ transpose(Matrix(A1))*transpose(B) + @test mul!(similar(B), transpose(A1), adjoint(B)) ≈ transpose(Matrix(A1))*B' + @test mul!(similar(B), adjoint(A1), transpose(B)) ≈ Matrix(A1)'*transpose(B) + @test mul!(similar(B), adjoint(A1), B) ≈ Matrix(A1)'*B + @test mul!(similar(B), transpose(A1), B) ≈ transpose(Matrix(A1))*B # test also vector methods B1 = vec(B[1,:]) - @test mul!(similar(B1),A1,B1) ≈ A1*B1 - @test mul!(similar(B1), adjoint(A1), B1) ≈ A1'*B1 - @test mul!(similar(B1), transpose(A1), B1) ≈ transpose(A1)*B1 + @test mul!(similar(B1), A1, B1) ≈ Matrix(A1)*B1 + @test mul!(similar(B1), adjoint(A1), B1) ≈ Matrix(A1)'*B1 + @test mul!(similar(B1), transpose(A1), B1) ≈ transpose(Matrix(A1))*B1 end #error handling Ann, Bmm, bm = A1, Matrix{eltyB}(undef, n+1, n+1), Vector{eltyB}(undef, n+1)