Skip to content

Commit

Permalink
Fix rdiv of complex lhs by real factorizations (#50671)
Browse files Browse the repository at this point in the history
Co-authored-by: Martin Holters <[email protected]>
(cherry picked from commit 210c5b5)
  • Loading branch information
dkarrasch authored and KristofferC committed Aug 10, 2023
1 parent 4799388 commit 6a8559d
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 30 deletions.
56 changes: 30 additions & 26 deletions stdlib/LinearAlgebra/src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ function Base.show(io::IO, ::MIME"text/plain", x::TransposeFactorization)
show(io, MIME"text/plain"(), parent(x))
end

function (\)(F::Factorization, B::AbstractVecOrMat)
require_one_based_indexing(B)
TFB = typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))
ldiv!(F, copy_similar(B, TFB))
end
(\)(F::TransposeFactorization, B::AbstractVecOrMat) = conj!(adjoint(F.parent) \ conj.(B))
# With a real lhs and complex rhs with the same precision, we can reinterpret
# the complex rhs as a real rhs with twice the number of columns or rows
function (\)(F::Factorization{T}, B::VecOrMat{Complex{T}}) where {T<:BlasReal}
Expand All @@ -151,32 +157,6 @@ end
(\)(F::AdjointFactorization{T}, B::VecOrMat{Complex{T}}) where {T<:BlasReal} =
@invoke \(F::typeof(F), B::VecOrMat)

function (/)(B::VecOrMat{Complex{T}}, F::Factorization{T}) where {T<:BlasReal}
require_one_based_indexing(B)
x = rdiv!(copy(reinterpret(T, B)), F)
return copy(reinterpret(Complex{T}, x))
end
# don't do the reinterpretation for [Adjoint/Transpose]Factorization
(/)(B::VecOrMat{Complex{T}}, F::TransposeFactorization{T}) where {T<:BlasReal} =
conj!(adjoint(parent(F)) \ conj.(B))
(/)(B::VecOrMat{Complex{T}}, F::AdjointFactorization{T}) where {T<:BlasReal} =
@invoke /(B::VecOrMat{Complex{T}}, F::Factorization{T})

function (\)(F::Factorization, B::AbstractVecOrMat)
require_one_based_indexing(B)
TFB = typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))
ldiv!(F, copy_similar(B, TFB))
end
(\)(F::TransposeFactorization, B::AbstractVecOrMat) = conj!(adjoint(F.parent) \ conj.(B))

function (/)(B::AbstractMatrix, F::Factorization)
require_one_based_indexing(B)
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
rdiv!(copy_similar(B, TFB), F)
end
(/)(A::AbstractMatrix, F::AdjointFactorization) = adjoint(adjoint(F) \ adjoint(A))
(/)(A::AbstractMatrix, F::TransposeFactorization) = transpose(transpose(F) \ transpose(A))

function ldiv!(Y::AbstractVector, A::Factorization, B::AbstractVector)
require_one_based_indexing(Y, B)
m, n = size(A)
Expand All @@ -200,3 +180,27 @@ function ldiv!(Y::AbstractMatrix, A::Factorization, B::AbstractMatrix)
return ldiv!(A, Y)
end
end

function (/)(B::AbstractMatrix, F::Factorization)
require_one_based_indexing(B)
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
rdiv!(copy_similar(B, TFB), F)
end
# reinterpretation trick for complex lhs and real factorization
function (/)(B::Union{Matrix{Complex{T}},AdjOrTrans{Complex{T},Vector{Complex{T}}}}, F::Factorization{T}) where {T<:BlasReal}
require_one_based_indexing(B)
x = rdiv!(copy(reinterpret(T, B)), F)
return copy(reinterpret(Complex{T}, x))
end
# don't do the reinterpretation for [Adjoint/Transpose]Factorization
(/)(B::Union{Matrix{Complex{T}},AdjOrTrans{Complex{T},Vector{Complex{T}}}}, F::TransposeFactorization{T}) where {T<:BlasReal} =
@invoke /(B::AbstractMatrix, F::Factorization)
(/)(B::Matrix{Complex{T}}, F::AdjointFactorization{T}) where {T<:BlasReal} =
@invoke /(B::AbstractMatrix, F::Factorization)
(/)(B::Adjoint{Complex{T},Vector{Complex{T}}}, F::AdjointFactorization{T}) where {T<:BlasReal} =
(F' \ B')'
(/)(B::Transpose{Complex{T},Vector{Complex{T}}}, F::TransposeFactorization{T}) where {T<:BlasReal} =
transpose(transpose(F) \ transpose(B))

rdiv!(B::AbstractMatrix, A::TransposeFactorization) = transpose(ldiv!(A.parent, transpose(B)))
rdiv!(B::AbstractMatrix, A::AdjointFactorization) = adjoint(ldiv!(A.parent, adjoint(B)))
1 change: 0 additions & 1 deletion stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,6 @@ function rdiv!(B::AbstractVecOrMat{<:Complex}, F::Hessenberg{<:Complex,<:Any,<:A
end

ldiv!(F::AdjointFactorization{<:Any,<:Hessenberg}, B::AbstractVecOrMat) = rdiv!(B', F')'
rdiv!(B::AbstractMatrix, F::AdjointFactorization{<:Any,<:Hessenberg}) = ldiv!(F', B')'

det(F::Hessenberg) = det(F.H; shift=F.μ)
logabsdet(F::Hessenberg) = logabsdet(F.H; shift=F.μ)
Expand Down
2 changes: 0 additions & 2 deletions stdlib/LinearAlgebra/src/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -709,8 +709,6 @@ function ldiv!(adjA::AdjointFactorization{<:Any,<:LU{T,Tridiagonal{T,V}}}, B::Ab
end

rdiv!(B::AbstractMatrix, A::LU) = transpose(ldiv!(transpose(A), transpose(B)))
rdiv!(B::AbstractMatrix, A::TransposeFactorization{<:Any,<:LU}) = transpose(ldiv!(A.parent, transpose(B)))
rdiv!(B::AbstractMatrix, A::AdjointFactorization{<:Any,<:LU}) = adjoint(ldiv!(A.parent, adjoint(B)))

# Conversions
AbstractMatrix(F::LU) = (F.L * F.U)[invperm(F.p),:]
Expand Down
4 changes: 3 additions & 1 deletion stdlib/LinearAlgebra/test/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,10 @@ let n = 10
@test H \ B A \ B H \ complex(B)
@test (H - I) \ B (A - I) \ B
@test (H - (3+4im)I) \ B (A - (3+4im)I) \ B
@test b' / H b' / A complex.(b') / H
@test b' / H b' / A complex(b') / H
@test transpose(b) / H transpose(b) / A transpose(complex(b)) / H
@test B' / H B' / A complex(B') / H
@test b' / H' complex(b)' / H'
@test B' / (H - I) B' / (A - I)
@test B' / (H - (3+4im)I) B' / (A - (3+4im)I)
@test (H - (3+4im)I)' \ B (A - (3+4im)I)' \ B
Expand Down
9 changes: 9 additions & 0 deletions stdlib/LinearAlgebra/test/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,15 @@ end
B = randn(elty, 5, 5)
@test rdiv!(transform(A), transform(lu(B))) transform(C) / transform(B)
end
for elty in (Float32, Float64, ComplexF64), transF in (identity, transpose),
transB in (transpose, adjoint), transT in (identity, complex)
A = randn(elty, 5, 5)
F = lu(A)
b = randn(transT(elty), 5)
@test rdiv!(transB(copy(b)), transF(F)) transB(b) / transF(F) transB(b) / transF(A)
B = randn(transT(elty), 5, 5)
@test rdiv!(copy(B), transF(F)) B / transF(F) B / transF(A)
end
end

@testset "transpose(A) / lu(B)' should not overwrite A (#36657)" begin
Expand Down

0 comments on commit 6a8559d

Please sign in to comment.