Compute real matrix logarithm and matrix square root using real arithmetic #39973

Merged: 70 commits, Mar 20, 2021
Commits
b5a1e13
Add failing test
sethaxen Mar 1, 2021
391eab5
Add sylvester methods for small matrices
sethaxen Mar 1, 2021
413290c
Add 2x2 real matrix square root
sethaxen Mar 1, 2021
cc60b10
Add real square root of quasitriangular matrix
sethaxen Mar 1, 2021
063b2e9
Simplify 2x2 real square root
sethaxen Mar 3, 2021
f384944
Rename functions to use quasitriu
sethaxen Mar 3, 2021
359c722
Avoid NaNs when eigenvalues are all zero
sethaxen Mar 3, 2021
5d2c60a
Reuse ranges
sethaxen Mar 3, 2021
7e980ed
Add clarifying comments
sethaxen Mar 3, 2021
2bb972f
Unify real and complex matrix square root
sethaxen Mar 3, 2021
e465ead
Add reference for real sqrt
sethaxen Mar 3, 2021
539b54d
Move quasitriu auxiliary functions to triangular.jl
sethaxen Mar 3, 2021
7c5ec4a
Ensure loops are type-stable and use simd
sethaxen Mar 4, 2021
e1e0484
Remove duplicate computation
sethaxen Mar 4, 2021
56a6cf9
Correctly promote for dimensionful A
sethaxen Mar 7, 2021
6f45ecb
Use simd directive
sethaxen Mar 8, 2021
20a4942
Test that UpperTriangular is returned by sqrt
sethaxen Mar 8, 2021
e39698f
Test sqrt for UnitUpperTriangular
sethaxen Mar 8, 2021
96922de
Test that return type is complex when input type is
sethaxen Mar 8, 2021
f175738
Test that output is complex when input is
sethaxen Mar 8, 2021
a47ccf4
Add failing test
sethaxen Mar 9, 2021
eb589cd
Separate type-stable from type-unstable part
sethaxen Mar 9, 2021
46561e1
Use generic sqrt_quasitriu for sqrt triu
sethaxen Mar 9, 2021
4c0799b
Avoid redundant matmul
sethaxen Mar 9, 2021
a829f81
Clarify comment
sethaxen Mar 9, 2021
c19b557
Return complex output for complex input
sethaxen Mar 9, 2021
0b53d89
Call log_quasitriu
sethaxen Mar 9, 2021
2b61405
Add failing test for log type-inferrability
sethaxen Mar 9, 2021
fd3624b
Realify or complexify as necessary
sethaxen Mar 9, 2021
0ba02c1
Call sqrt_quasitriu directly
sethaxen Mar 9, 2021
075c3bc
Refactor sqrt_diag!
sethaxen Mar 9, 2021
39411d1
Simplify utility function
sethaxen Mar 9, 2021
7141bea
Add comment
sethaxen Mar 9, 2021
435c354
Compute accurate block-diagonal
sethaxen Mar 9, 2021
7e9d0ef
Compute superdiagonal for quasi triu A0
sethaxen Mar 9, 2021
86e502e
Compute accurate block superdiagonal
sethaxen Mar 9, 2021
889fd65
Avoid full LU decomposition in inner loop
sethaxen Mar 9, 2021
3c2f567
Avoid promotion to improve type-stability
sethaxen Mar 9, 2021
2813dd1
Modify return type if necessary
sethaxen Mar 9, 2021
023e131
Clarify comment
sethaxen Mar 9, 2021
de99f00
Add comments
sethaxen Mar 9, 2021
9775592
Call log_quasitriu on quasitriu matrices
sethaxen Mar 9, 2021
e15b92e
Document quasi-triangular algorithm
sethaxen Mar 9, 2021
0ac37df
Remove test
sethaxen Mar 9, 2021
e817a37
Rearrange definition
sethaxen Mar 9, 2021
5946cbe
Add compatibility for unit triangular matrices
sethaxen Mar 9, 2021
2ee1585
Release constraints on tests
sethaxen Mar 9, 2021
37d0344
Merge branch 'master' into reallog
sethaxen Mar 9, 2021
23becc5
Separate copying of A from log computation
sethaxen Mar 10, 2021
ff9d8d7
Revert "Separate copying of A from log computation"
sethaxen Mar 10, 2021
cfa6ea1
Use Givens rotations
sethaxen Mar 10, 2021
e14a819
Compute Schur in-place when possible
sethaxen Mar 10, 2021
5d37a0f
Always allocate a copy
sethaxen Mar 10, 2021
3278edd
Fix block indexing
sethaxen Mar 10, 2021
5b96c72
Compute sqrt in-place
sethaxen Mar 10, 2021
ae6daae
Overwrite AmI
sethaxen Mar 10, 2021
f673914
Reduce allocations in Pade approximation
sethaxen Mar 10, 2021
c96a05b
Use T
sethaxen Mar 11, 2021
9db29fc
Don't unnecessarily unwrap
sethaxen Mar 11, 2021
0c475bc
Test remaining log branches
sethaxen Mar 11, 2021
96b6f3d
Add additional matrix square root tests
sethaxen Mar 11, 2021
d1da048
Separate type-unstable from type-stable part
sethaxen Mar 11, 2021
69e298d
Use Ref instead of a Vector
sethaxen Mar 19, 2021
9785686
Eliminate allocation in checksquare
sethaxen Mar 19, 2021
de1682c
Refactor param choosing code to own function
sethaxen Mar 19, 2021
27d2063
Comment section
sethaxen Mar 19, 2021
def31ba
Use more descriptive variable name
sethaxen Mar 19, 2021
cb941b1
Reuse temporaries
sethaxen Mar 19, 2021
d1d7095
Add reference
sethaxen Mar 19, 2021
de5fdd7
More accurately describe condition
sethaxen Mar 20, 2021
125 changes: 82 additions & 43 deletions stdlib/LinearAlgebra/src/dense.jl
@@ -679,7 +679,7 @@ function rcswap!(i::Integer, j::Integer, X::StridedMatrix{<:Number})
 end
 
 """
-    log(A{T}::StridedMatrix{T})
+    log(A::StridedMatrix)
 
 If `A` has no negative real eigenvalue, compute the principal matrix logarithm of `A`, i.e.
 the unique matrix ``X`` such that ``e^X = A`` and ``-\\pi < Im(\\lambda) < \\pi`` for all
@@ -688,9 +688,10 @@ matrix function is returned whenever possible.
 
 If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is
 used, if `A` is triangular an improved version of the inverse scaling and squaring method is
-employed (see [^AH12] and [^AHR13]). For general matrices, the complex Schur form
-([`schur`](@ref)) is computed and the triangular algorithm is used on the
-triangular factor.
+employed (see [^AH12] and [^AHR13]). For general matrices, if a real logarithm exists, then
+the real Schur form is computed. Otherwise, the complex Schur form is computed. Then
+the upper (quasi-)triangular algorithm in [^AHR13] is used on the upper (quasi-)triangular
+factor.
 
 [^AH12]: Awad H. Al-Mohy and Nicholas J. Higham, "Improved inverse scaling and squaring algorithms for the matrix logarithm", SIAM Journal on Scientific Computing, 34(4), 2012, C153-C169. [doi:10.1137/110852553](https://doi.org/10.1137/110852553)

@@ -713,27 +714,28 @@ function log(A::StridedMatrix)
     # If possible, use diagonalization
     if ishermitian(A)
         logHermA = log(Hermitian(A))
-        return isa(logHermA, Hermitian) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA)
-    end
-
-    # Use Schur decomposition
-    n = checksquare(A)
-    if istriu(A)
-        return triu!(parent(log(UpperTriangular(complex(A)))))
-    else
-        if isreal(A)
-            SchurF = schur(real(A))
+        return ishermitian(logHermA) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA)
+    elseif istriu(A)
+        return triu!(parent(log(UpperTriangular(A))))
+    elseif isreal(A)
+        SchurF = schur(real(A))
+        if istriu(SchurF.T)
+            logA = SchurF.Z * log(UpperTriangular(SchurF.T)) * SchurF.Z'
         else
-            SchurF = schur(A)
-        end
-        if !istriu(SchurF.T)
-            SchurS = schur(complex(SchurF.T))
-            logT = SchurS.Z * log(UpperTriangular(SchurS.T)) * SchurS.Z'
-            return SchurF.Z * logT * SchurF.Z'
-        else
-            R = log(UpperTriangular(complex(SchurF.T)))
-            return SchurF.Z * R * SchurF.Z'
+            # real log exists whenever all eigenvalues are positive
+            is_log_real = !any(x -> isreal(x) && real(x) ≤ 0, SchurF.values)
+            if is_log_real
+                logA = SchurF.Z * log_quasitriu(SchurF.T) * SchurF.Z'
+            else
+                SchurS = schur!(complex(SchurF.T))
+                Z = SchurF.Z * SchurS.Z
Member

Why do you square Z here?

Contributor Author

I'm actually composing two different Zs to eliminate one matrix multiplication. In the original implementation, it was essentially FZ * SZ * log(T) * SZ' * FZ'
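
The identity behind this reply can be written out explicitly. The symbols below are illustrative notation rather than names from the PR: $Z_F$ and $T$ stand for the real Schur factors of $A$ (`SchurF.Z`, `SchurF.T`), $Z_S$ and $U$ for the complex Schur factors of $T$ (`SchurS.Z`, `SchurS.T`), and $f$ for the matrix function (`log` here).

```latex
A = Z_F \, T \, Z_F^{*}, \qquad T = Z_S \, U \, Z_S^{*}
\;\Longrightarrow\;
A = (Z_F Z_S) \, U \, (Z_F Z_S)^{*}

f(A) = Z_F \bigl( Z_S \, f(U) \, Z_S^{*} \bigr) Z_F^{*}
     = Z \, f(U) \, Z^{*}, \qquad Z = Z_F Z_S
```

Evaluating the left grouping costs four matrix products, while forming $Z$ once and then computing $Z \, f(U) \, Z^{*}$ costs three, which matches the saving of one multiplication described in the reply.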

+                logA = Z * log(UpperTriangular(SchurS.T)) * Z'
+            end
         end
+        return eltype(A) <: Complex ? complex(logA) : logA
+    else
+        SchurF = schur(A)
+        return SchurF.vectors * log(UpperTriangular(SchurF.T)) * SchurF.vectors'
     end
 end

@@ -755,13 +757,21 @@ defaults to machine precision scaled by `size(A,1)`.
 Otherwise, the square root is determined by means of the
 Björck-Hammarling method [^BH83], which computes the complex Schur form ([`schur`](@ref))
 and then the complex square root of the triangular factor.
+If a real square root exists, then an extension of this method [^H87] that computes the real
+Schur form and then the real square root of the quasi-triangular factor is instead used.
 
 [^BH83]:
 
     Åke Björck and Sven Hammarling, "A Schur method for the square root of a matrix",
     Linear Algebra and its Applications, 52-53, 1983, 127-140.
     [doi:10.1016/0024-3795(83)80010-X](https://doi.org/10.1016/0024-3795(83)80010-X)
 
+[^H87]:
+
+    Nicholas J. Higham, "Computing real square roots of a real matrix",
+    Linear Algebra and its Applications, 88-89, 1987, 405-430.
+    [doi:10.1016/0024-3795(87)90118-2](https://doi.org/10.1016/0024-3795(87)90118-2)
+
 # Examples
 ```jldoctest
 julia> A = [4 0; 0 4]
@@ -775,31 +785,32 @@ julia> sqrt(A)
  0.0  2.0
 ```
 """
-function sqrt(A::StridedMatrix{<:Real})
-    if issymmetric(A)
-        return copytri!(parent(sqrt(Symmetric(A))), 'U')
-    end
-    n = checksquare(A)
-    if istriu(A)
-        return triu!(parent(sqrt(UpperTriangular(A))))
-    else
-        SchurF = schur(complex(A))
-        R = triu!(parent(sqrt(UpperTriangular(SchurF.T)))) # unwrapping unnecessary?
-        return SchurF.vectors * R * SchurF.vectors'
-    end
-end
-function sqrt(A::StridedMatrix{<:Complex})
+function sqrt(A::StridedMatrix{T}) where {T<:Union{Real,Complex}}
     if ishermitian(A)
         sqrtHermA = sqrt(Hermitian(A))
-        return isa(sqrtHermA, Hermitian) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA)
-    end
-    n = checksquare(A)
-    if istriu(A)
+        return ishermitian(sqrtHermA) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA)
+    elseif istriu(A)
         return triu!(parent(sqrt(UpperTriangular(A))))
+    elseif isreal(A)
+        SchurF = schur(real(A))
+        if istriu(SchurF.T)
+            sqrtA = SchurF.Z * sqrt(UpperTriangular(SchurF.T)) * SchurF.Z'
+        else
+            # real sqrt exists whenever no eigenvalues are negative
+            is_sqrt_real = !any(x -> isreal(x) && real(x) < 0, SchurF.values)
+            # sqrt_quasitriu uses LAPACK functions for non-triu inputs
+            if typeof(sqrt(zero(T))) <: BlasFloat && is_sqrt_real
+                sqrtA = SchurF.Z * sqrt_quasitriu(SchurF.T) * SchurF.Z'
+            else
+                SchurS = schur!(complex(SchurF.T))
+                Z = SchurF.Z * SchurS.Z
+                sqrtA = Z * sqrt(UpperTriangular(SchurS.T)) * Z'
+            end
+        end
+        return eltype(A) <: Complex ? complex(sqrtA) : sqrtA
     else
         SchurF = schur(A)
-        R = triu!(parent(sqrt(UpperTriangular(SchurF.T)))) # unwrapping unnecessary?
-        return SchurF.vectors * R * SchurF.vectors'
+        return SchurF.vectors * sqrt(UpperTriangular(SchurF.T)) * SchurF.vectors'
     end
 end
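
As a rough illustration of the case the `isreal(A)` branches of `sqrt` and `log` target, the sketch below uses a 90-degree rotation matrix, whose eigenvalues are ±im, so no eigenvalue is real and non-positive. The matrix and variable names are chosen for this example only. The checks are round trips that should hold whether the results come back real or complex; on a Julia build that includes this change the element types are expected to be real, but that is an assumption to verify, not something the sketch asserts.

```julia
using LinearAlgebra

# 90-degree rotation: eigenvalues are ±im, so none is real and non-positive.
A = [0.0 -1.0; 1.0 0.0]

S = sqrt(A)   # principal square root
L = log(A)    # principal logarithm

# Round-trip checks that hold whether the results are stored as real or complex.
println(S * S ≈ A)
println(exp(L) ≈ A)

# Element types depend on the Julia version; inspect rather than assume.
println(eltype(S), " ", eltype(L))
```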

@@ -1526,6 +1537,34 @@ function sylvester(A::StridedMatrix{T},B::StridedMatrix{T},C::StridedMatrix{T})
 end
 sylvester(A::StridedMatrix{T}, B::StridedMatrix{T}, C::StridedMatrix{T}) where {T<:Integer} = sylvester(float(A), float(B), float(C))
 
+Base.@propagate_inbounds function _sylvester_2x1!(A, B, C)
+    b = B[1]
+    a21, a12 = A[2, 1], A[1, 2]
+    m11 = b + A[1, 1]
+    m22 = b + A[2, 2]
+    d = m11 * m22 - a12 * a21
+    c1, c2 = C
+    C[1] = (a12 * c2 - m22 * c1) / d
+    C[2] = (a21 * c1 - m11 * c2) / d
+    return C
+end
+Base.@propagate_inbounds function _sylvester_1x2!(A, B, C)
+    a = A[1]
+    b21, b12 = B[2, 1], B[1, 2]
+    m11 = a + B[1, 1]
+    m22 = a + B[2, 2]
+    d = m11 * m22 - b21 * b12
+    c1, c2 = C
+    C[1] = (b21 * c2 - m22 * c1) / d
+    C[2] = (b12 * c1 - m11 * c2) / d
+    return C
+end
+function _sylvester_2x2!(A, B, C)
+    _, scale = LAPACK.trsyl!('N', 'N', A, B, C)
+    rmul!(C, -inv(scale))
+    return C
+end
+
 sylvester(a::Union{Real,Complex}, b::Union{Real,Complex}, c::Union{Real,Complex}) = -c / (a + b)
 
 # AX + XA' + C = 0
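
The closed-form expressions in `_sylvester_2x1!` can be checked by hand. The derivation below is a sketch under the assumption, consistent with the scalar method `-c / (a + b)` in this hunk, that the helpers solve $AX + XB + C = 0$ and overwrite `C` with `X`. The vector names $x$ and $c$ are notation introduced here, not identifiers from the PR.

```latex
% Assumed setting: A is 2x2, B = (b) is 1x1, c = (c_1, c_2)^T holds C, x holds X.
A x + x b + c = 0 \iff (A + b I)\, x = -c,
\qquad
M = A + b I = \begin{pmatrix} m_{11} & a_{12} \\ a_{21} & m_{22} \end{pmatrix},
\quad d = m_{11} m_{22} - a_{12} a_{21}

x = -M^{-1} c
  = -\frac{1}{d}
    \begin{pmatrix} m_{22} & -a_{12} \\ -a_{21} & m_{11} \end{pmatrix}
    \begin{pmatrix} c_1 \\ c_2 \end{pmatrix}
  = \frac{1}{d}
    \begin{pmatrix} a_{12} c_2 - m_{22} c_1 \\ a_{21} c_1 - m_{11} c_2 \end{pmatrix}
```

The two components agree with the assignments to `C[1]` and `C[2]`. The `_sylvester_1x2!` case follows the same way from $x (a I + B) = -c$ with $x$ and $c$ as row vectors.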
6 changes: 3 additions & 3 deletions stdlib/LinearAlgebra/src/lapack.jl
@@ -6449,15 +6449,15 @@ for (fn, elty, relty) in ((:dtrsyl_, :Float64, :Float64),
                         B::AbstractMatrix{$elty}, C::AbstractMatrix{$elty}, isgn::Int=1)
             require_one_based_indexing(A, B, C)
             chkstride1(A, B, C)
-            m, n = checksquare(A, B)
+            m, n = checksquare(A), checksquare(B)
             lda = max(1, stride(A, 2))
             ldb = max(1, stride(B, 2))
             m1, n1 = size(C)
             if m != m1 || n != n1
                 throw(DimensionMismatch("dimensions of A, ($m,$n), and C, ($m1,$n1), must match"))
             end
             ldc = max(1, stride(C, 2))
-            scale = Vector{$relty}(undef, 1)
+            scale = Ref{$relty}()
             info = Ref{BlasInt}()
             ccall((@blasfunc($fn), libblastrampoline), Cvoid,
                 (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt},
@@ -6467,7 +6467,7 @@ for (fn, elty, relty) in ((:dtrsyl_, :Float64, :Float64),
                 A, lda, B, ldb, C, ldc,
                 scale, info, 1, 1)
             chklapackerror(info[])
-            C, scale[1]
+            C, scale[]
         end
     end
 end
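
The `scale` change in this hunk swaps a one-element `Vector` for a `Ref`. The sketch below only contrasts how the two containers are created and read. It does not call LAPACK, and the value `0.5` is a made-up stand-in for what the routine would write into the output argument.

```julia
# One-element Vector: a full array object, read back with v[1].
v = Vector{Float64}(undef, 1)
v[1] = 0.5            # stand-in for a value written by a C routine

# Ref: a single mutable box, read back with r[].
r = Ref{Float64}()
r[] = 0.5             # same stand-in value

println(v[1] == r[])
```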