Skip to content

Commit

Permalink
Compute real matrix logarithm and matrix square root using real arith…
Browse files Browse the repository at this point in the history
…metic (JuliaLang#39973)

* Add failing test

* Add sylvester methods for small matrices

* Add 2x2 real matrix square root

* Add real square root of quasitriangular matrix

* Simplify 2x2 real square root

* Rename functions to use quasitriu

* Avoid NaNs when eigenvalues are all zero

* Reuse ranges

* Add clarifying comments

* Unify real and complex matrix square root

* Add reference for real sqrt

* Move quasitriu auxiliary functions to triangular.jl

* Ensure loops are type-stable and use simd

* Remove duplicate computation

* Correctly promote for dimensionful A

* Use simd directive

* Test that UpperTriangular is returned by sqrt

* Test sqrt for UnitUpperTriangular

* Test that return type is complex when input type is

* Test that output is complex when input is

* Add failing test

* Separate type-stable from type-unstable part

* Use generic sqrt_quasitriu for sqrt triu

* Avoid redundant matmul

* Clarify comment

* Return complex output for complex input

* Call log_quasitriu

* Add failing test for log type-inferrability

* Realify or complexify as necessary

* Call sqrt_quasitriu directly

* Refactor sqrt_diag!

* Simplify utility function

* Add comment

* Compute accurate block-diagonal

* Compute superdiagonal for quasi triu A0

* Compute accurate block superdiagonal

* Avoid full LU decomposition in inner loop

* Avoid promotion to improve type-stability

* Modify return type if necessary

* Clarify comment

* Add comments

* Call log_quasitriu on quasitriu matrices

* Document quasi-triangular algorithm

* Remove test

This matrix has eigenvalues to close to zero that its eltype is not stable

* Rearrange definition

* Add compatibility for unit triangular matrices

* Release constraints on tests

* Separate copying of A from log computation

* Revert "Separate copying of A from log computation"

This reverts commit 23becc5.

* Use Givens rotations

* Compute Schur in-place when possible

* Always allocate a copy

* Fix block indexing

* Compute sqrt in-place

* Overwrite AmI

* Reduce allocations in Pade approximation

* Use T

* Don't unnecessarily unwrap

* Test remaining log branches

* Add additional matrix square root tests

* Separate type-unstable from type-stable part

This substantially reduces allocations for some reason

* Use Ref instead of a Vector

* Eliminate allocation in checksquare

* Refactor param choosing code to own function

* Comment section

* Use more descriptive variable name

* Reuse temporaries

* Add reference

* More accurately describe condition
  • Loading branch information
sethaxen authored and ElOceanografo committed May 4, 2021
1 parent f0e13fa commit 59c1ce9
Show file tree
Hide file tree
Showing 5 changed files with 743 additions and 187 deletions.
125 changes: 82 additions & 43 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ function rcswap!(i::Integer, j::Integer, X::StridedMatrix{<:Number})
end

"""
log(A{T}::StridedMatrix{T})
log(A::StridedMatrix)
If `A` has no negative real eigenvalue, compute the principal matrix logarithm of `A`, i.e.
the unique matrix ``X`` such that ``e^X = A`` and ``-\\pi < Im(\\lambda) < \\pi`` for all
Expand All @@ -688,9 +688,10 @@ matrix function is returned whenever possible.
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is
used, if `A` is triangular an improved version of the inverse scaling and squaring method is
employed (see [^AH12] and [^AHR13]). For general matrices, the complex Schur form
([`schur`](@ref)) is computed and the triangular algorithm is used on the
triangular factor.
employed (see [^AH12] and [^AHR13]). If `A` is real with no negative eigenvalues, then
the real Schur form is computed. Otherwise, the complex Schur form is computed. Then
the upper (quasi-)triangular algorithm in [^AHR13] is used on the upper (quasi-)triangular
factor.
[^AH12]: Awad H. Al-Mohy and Nicholas J. Higham, "Improved inverse scaling and squaring algorithms for the matrix logarithm", SIAM Journal on Scientific Computing, 34(4), 2012, C153-C169. [doi:10.1137/110852553](https://doi.org/10.1137/110852553)
Expand All @@ -713,27 +714,28 @@ function log(A::StridedMatrix)
# If possible, use diagonalization
if ishermitian(A)
logHermA = log(Hermitian(A))
return isa(logHermA, Hermitian) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA)
end

# Use Schur decomposition
n = checksquare(A)
if istriu(A)
return triu!(parent(log(UpperTriangular(complex(A)))))
else
if isreal(A)
SchurF = schur(real(A))
return ishermitian(logHermA) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA)
elseif istriu(A)
return triu!(parent(log(UpperTriangular(A))))
elseif isreal(A)
SchurF = schur(real(A))
if istriu(SchurF.T)
logA = SchurF.Z * log(UpperTriangular(SchurF.T)) * SchurF.Z'
else
SchurF = schur(A)
end
if !istriu(SchurF.T)
SchurS = schur(complex(SchurF.T))
logT = SchurS.Z * log(UpperTriangular(SchurS.T)) * SchurS.Z'
return SchurF.Z * logT * SchurF.Z'
else
R = log(UpperTriangular(complex(SchurF.T)))
return SchurF.Z * R * SchurF.Z'
# real log exists whenever all eigenvalues are positive
is_log_real = !any(x -> isreal(x) && real(x) 0, SchurF.values)
if is_log_real
logA = SchurF.Z * log_quasitriu(SchurF.T) * SchurF.Z'
else
SchurS = schur!(complex(SchurF.T))
Z = SchurF.Z * SchurS.Z
logA = Z * log(UpperTriangular(SchurS.T)) * Z'
end
end
return eltype(A) <: Complex ? complex(logA) : logA
else
SchurF = schur(A)
return SchurF.vectors * log(UpperTriangular(SchurF.T)) * SchurF.vectors'
end
end

Expand All @@ -755,13 +757,21 @@ defaults to machine precision scaled by `size(A,1)`.
Otherwise, the square root is determined by means of the
Björck-Hammarling method [^BH83], which computes the complex Schur form ([`schur`](@ref))
and then the complex square root of the triangular factor.
If a real square root exists, then an extension of this method [^H87] that computes the real
Schur form and then the real square root of the quasi-triangular factor is instead used.
[^BH83]:
Åke Björck and Sven Hammarling, "A Schur method for the square root of a matrix",
Linear Algebra and its Applications, 52-53, 1983, 127-140.
[doi:10.1016/0024-3795(83)80010-X](https://doi.org/10.1016/0024-3795(83)80010-X)
[^H87]:
Nicholas J. Higham, "Computing real square roots of a real matrix",
Linear Algebra and its Applications, 88-89, 1987, 405-430.
[doi:10.1016/0024-3795(87)90118-2](https://doi.org/10.1016/0024-3795(87)90118-2)
# Examples
```jldoctest
julia> A = [4 0; 0 4]
Expand All @@ -775,31 +785,32 @@ julia> sqrt(A)
0.0 2.0
```
"""
function sqrt(A::StridedMatrix{<:Real})
if issymmetric(A)
return copytri!(parent(sqrt(Symmetric(A))), 'U')
end
n = checksquare(A)
if istriu(A)
return triu!(parent(sqrt(UpperTriangular(A))))
else
SchurF = schur(complex(A))
R = triu!(parent(sqrt(UpperTriangular(SchurF.T)))) # unwrapping unnecessary?
return SchurF.vectors * R * SchurF.vectors'
end
end
function sqrt(A::StridedMatrix{<:Complex})
function sqrt(A::StridedMatrix{T}) where {T<:Union{Real,Complex}}
if ishermitian(A)
sqrtHermA = sqrt(Hermitian(A))
return isa(sqrtHermA, Hermitian) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA)
end
n = checksquare(A)
if istriu(A)
return ishermitian(sqrtHermA) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA)
elseif istriu(A)
return triu!(parent(sqrt(UpperTriangular(A))))
elseif isreal(A)
SchurF = schur(real(A))
if istriu(SchurF.T)
sqrtA = SchurF.Z * sqrt(UpperTriangular(SchurF.T)) * SchurF.Z'
else
# real sqrt exists whenever no eigenvalues are negative
is_sqrt_real = !any(x -> isreal(x) && real(x) < 0, SchurF.values)
# sqrt_quasitriu uses LAPACK functions for non-triu inputs
if typeof(sqrt(zero(T))) <: BlasFloat && is_sqrt_real
sqrtA = SchurF.Z * sqrt_quasitriu(SchurF.T) * SchurF.Z'
else
SchurS = schur!(complex(SchurF.T))
Z = SchurF.Z * SchurS.Z
sqrtA = Z * sqrt(UpperTriangular(SchurS.T)) * Z'
end
end
return eltype(A) <: Complex ? complex(sqrtA) : sqrtA
else
SchurF = schur(A)
R = triu!(parent(sqrt(UpperTriangular(SchurF.T)))) # unwrapping unnecessary?
return SchurF.vectors * R * SchurF.vectors'
return SchurF.vectors * sqrt(UpperTriangular(SchurF.T)) * SchurF.vectors'
end
end

Expand Down Expand Up @@ -1526,6 +1537,34 @@ function sylvester(A::StridedMatrix{T},B::StridedMatrix{T},C::StridedMatrix{T})
end
sylvester(A::StridedMatrix{T}, B::StridedMatrix{T}, C::StridedMatrix{T}) where {T<:Integer} = sylvester(float(A), float(B), float(C))

Base.@propagate_inbounds function _sylvester_2x1!(A, B, C)
b = B[1]
a21, a12 = A[2, 1], A[1, 2]
m11 = b + A[1, 1]
m22 = b + A[2, 2]
d = m11 * m22 - a12 * a21
c1, c2 = C
C[1] = (a12 * c2 - m22 * c1) / d
C[2] = (a21 * c1 - m11 * c2) / d
return C
end
Base.@propagate_inbounds function _sylvester_1x2!(A, B, C)
a = A[1]
b21, b12 = B[2, 1], B[1, 2]
m11 = a + B[1, 1]
m22 = a + B[2, 2]
d = m11 * m22 - b21 * b12
c1, c2 = C
C[1] = (b21 * c2 - m22 * c1) / d
C[2] = (b12 * c1 - m11 * c2) / d
return C
end
function _sylvester_2x2!(A, B, C)
_, scale = LAPACK.trsyl!('N', 'N', A, B, C)
rmul!(C, -inv(scale))
return C
end

sylvester(a::Union{Real,Complex}, b::Union{Real,Complex}, c::Union{Real,Complex}) = -c / (a + b)

# AX + XA' + C = 0
Expand Down
6 changes: 3 additions & 3 deletions stdlib/LinearAlgebra/src/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6449,15 +6449,15 @@ for (fn, elty, relty) in ((:dtrsyl_, :Float64, :Float64),
B::AbstractMatrix{$elty}, C::AbstractMatrix{$elty}, isgn::Int=1)
require_one_based_indexing(A, B, C)
chkstride1(A, B, C)
m, n = checksquare(A, B)
m, n = checksquare(A), checksquare(B)
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
m1, n1 = size(C)
if m != m1 || n != n1
throw(DimensionMismatch("dimensions of A, ($m,$n), and C, ($m1,$n1), must match"))
end
ldc = max(1, stride(C, 2))
scale = Vector{$relty}(undef, 1)
scale = Ref{$relty}()
info = Ref{BlasInt}()
ccall((@blasfunc($fn), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt},
Expand All @@ -6467,7 +6467,7 @@ for (fn, elty, relty) in ((:dtrsyl_, :Float64, :Float64),
A, lda, B, ldb, C, ldc,
scale, info, 1, 1)
chklapackerror(info[])
C, scale[1]
C, scale[]
end
end
end
Expand Down
Loading

0 comments on commit 59c1ce9

Please sign in to comment.