Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rules for dense matrix exponential #351

Merged
merged 34 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
6e8358c
Add matfun.jl file
sethaxen Jan 18, 2021
f987f35
Add matfun docstrings
sethaxen Jan 18, 2021
e6b92c9
Add exp matrix function
sethaxen Jan 18, 2021
7624ee5
At least store one intermediate
sethaxen Jan 18, 2021
a7792a5
Test exp!
sethaxen Jan 18, 2021
6d6b4cb
Make pullback type-inferrable
sethaxen Jan 18, 2021
3645d75
Add clearer test label
sethaxen Jan 18, 2021
937f2ac
Create as hermitian
sethaxen Jan 18, 2021
b48204c
Test rrule
sethaxen Jan 18, 2021
2c19bba
Add comment about relationship between pushforward and pullback
sethaxen Jan 18, 2021
58f6005
Add header
sethaxen Jan 18, 2021
6ee1759
Add reference to Frechet deriv paper
sethaxen Jan 18, 2021
b1a2980
Run JuliaFormatter
sethaxen Jan 18, 2021
e860b3e
Reduce comment spacing from code
sethaxen Jan 18, 2021
8f665ac
Update src/rulesets/LinearAlgebra/matfun.jl
sethaxen Jan 18, 2021
9e565ae
Correctly handle balancing
sethaxen Jan 18, 2021
71134fd
Test imbalanced matrix A
sethaxen Jan 18, 2021
bd48565
Increment version number
sethaxen Jan 18, 2021
062b11d
Merge branch 'exp2' of https://github.com/sethaxen/ChainRules.jl into…
sethaxen Jan 18, 2021
dc1b1ab
Apply suggestions from code review
sethaxen Jan 19, 2021
57aea17
Change signature of _matfun_frechet
sethaxen Jan 20, 2021
e2e6605
Give math for Frechet derivative
sethaxen Jan 20, 2021
976af09
Change Frechet notation
sethaxen Jan 20, 2021
d7d20ba
Add _matfun_frechet_adjoint
sethaxen Jan 20, 2021
b0ae61c
Simplify hermitian code
sethaxen Jan 20, 2021
62b963b
Correct comment
sethaxen Jan 20, 2021
87e4c53
Remove comments
sethaxen Jan 20, 2021
9bd06b1
Use abbreviated SHA
sethaxen Jan 20, 2021
2ed06e6
Link
sethaxen Jan 20, 2021
156e6f5
Update comment
sethaxen Jan 20, 2021
9a63d13
Move comment up
sethaxen Jan 20, 2021
5ba193d
Move comment further up
sethaxen Jan 20, 2021
49df929
Update docstrings
sethaxen Jan 20, 2021
8c27276
Push header to same level as rules
sethaxen Jan 20, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 68 additions & 50 deletions src/rulesets/LinearAlgebra/matfun.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,45 +86,63 @@ end
## Destructive matrix exponential using algorithm from Higham, 2008,
## "Functions of Matrices: Theory and Computation", SIAM
## Adapted from LinearAlgebra.exp! with return of intermediates
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
function _matfun!(::typeof(exp), A::StridedMatrix{T}) where T<:BlasFloat
function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat}
n = LinearAlgebra.checksquare(A)
ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
nA = opnorm(A, 1)
Inn = Matrix{T}(I, n, n)
nA = opnorm(A, 1)
Inn = Matrix{T}(I, n, n)
## For sufficiently small nA, use lower order Padé-Approximations
if (nA <= 2.1)
if nA > 0.95
C = T[17643225600.,8821612800.,2075673600.,302702400.,
30270240., 2162160., 110880., 3960.,
90., 1.]
C = T[
17643225600.0,
8821612800.0,
2075673600.0,
302702400.0,
30270240.0,
2162160.0,
110880.0,
3960.0,
90.0,
1.0,
]
elseif nA > 0.25
C = T[17297280.,8648640.,1995840.,277200.,
25200., 1512., 56., 1.]
C = T[17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0]
elseif nA > 0.015
C = T[30240.,15120.,3360.,
420., 30., 1.]
C = T[30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0]
else
C = T[120.,60.,12.,1.]
C = T[120.0, 60.0, 12.0, 1.0]
end
si = 0
else
C = T[64764752532480000.,32382376266240000.,7771770303897600.,
1187353796428800., 129060195264000., 10559470521600.,
670442572800., 33522128640., 1323241920.,
40840800., 960960., 16380.,
182., 1.]
s = log2(nA/5.4) # power of 2 later reversed by squaring
si = ceil(Int,s)
C = T[
64764752532480000.0,
32382376266240000.0,
7771770303897600.0,
1187353796428800.0,
129060195264000.0,
10559470521600.0,
670442572800.0,
33522128640.0,
1323241920.0,
40840800.0,
960960.0,
16380.0,
182.0,
1.0,
]
s = log2(nA / 5.4) # power of 2 later reversed by squaring
si = ceil(Int, s)
end

if si > 0
A ./= convert(T,2^si)
A ./= convert(T, 2^si)
end

A2 = A * A
P = copy(Inn)
W = C[2] * P
V = C[1] * P
P = copy(Inn)
W = C[2] * P
V = C[1] * P
Apows = typeof(P)[]
for k in 1:(div(size(C, 1), 2) - 1)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
k2 = 2 * k
Expand All @@ -135,32 +153,36 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where T<:BlasFloat
end
U = A * W
X = V + U
F = lu!(V-U) # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization
F = lu!(V - U) # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization
ldiv!(F, X)
Xpows = typeof(X)[X]
if si > 0 # squaring to reverse dividing by power of 2
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
for t=1:si
for t in 1:si
X *= X
push!(Xpows, X)
end
end

# Undo the balancing
for j = ilo:ihi
for j in ilo:ihi
scj = scale[j]
for i = 1:n
X[j,i] *= scj
for i in 1:n
X[j, i] *= scj
end
for i = 1:n
X[i,j] /= scj
for i in 1:n
X[i, j] /= scj
end
end

if ilo > 1 # apply lower permutations in reverse order
for j in (ilo-1):-1:1; LinearAlgebra.rcswap!(j, Int(scale[j]), X) end
for j in (ilo - 1):-1:1
LinearAlgebra.rcswap!(j, Int(scale[j]), X)
end
end
if ihi < n # apply upper permutations in forward order
for j in (ihi+1):n; LinearAlgebra.rcswap!(j, Int(scale[j]), X) end
for j in (ihi + 1):n
LinearAlgebra.rcswap!(j, Int(scale[j]), X)
end
end
return X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows)
end
Expand All @@ -171,20 +193,16 @@ end
# Condition Number Estimation", SIAM. 30 (4). pp. 1639-1657.
# http://eprints.maths.manchester.ac.uk/id/eprint/1218
function _matfun_frechet!(
::typeof(exp),
A::StridedMatrix{T},
X,
ΔA,
(ilo, ihi, scale, C, si, Apows, W, F, Xpows),
::typeof(exp), A::StridedMatrix{T}, X, ΔA, (ilo, ihi, scale, C, si, Apows, W, F, Xpows)
) where {T<:BlasFloat}
n = LinearAlgebra.checksquare(A)
for j = ilo:ihi
for j in ilo:ihi
scj = scale[j]
for i = 1:n
ΔA[j,i] /= scj
for i in 1:n
ΔA[j, i] /= scj
end
for i = 1:n
ΔA[i,j] *= scj
for i in 1:n
ΔA[i, j] *= scj
end
end

Expand All @@ -199,7 +217,7 @@ function _matfun_frechet!(
∂P = copy(∂A2)
∂W = C[4] * ∂P
∂V = C[3] * ∂P
for k in 2:(length(Apows)-1)
for k in 2:(length(Apows) - 1)
k2 = 2 * k
P = Apows[k - 1]
∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P
Expand All @@ -213,29 +231,29 @@ function _matfun_frechet!(
ldiv!(F, ∂X)

if si > 0
for t = 1:(length(Xpows)-1)
for t in 1:(length(Xpows) - 1)
X = Xpows[t]
∂X, ∂temp = mul!(mul!(∂temp, X, ∂X), ∂X, X, true, true), ∂X
end
end

for j = ilo:ihi
for j in ilo:ihi
scj = scale[j]
for i = 1:n
∂X[j,i] *= scj
for i in 1:n
∂X[j, i] *= scj
end
for i = 1:n
∂X[i,j] /= scj
for i in 1:n
∂X[i, j] /= scj
end
end

if ilo > 1 # apply lower permutations in reverse order
for j in (ilo-1):-1:1
for j in (ilo - 1):-1:1
LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X)
end
end
if ihi < n # apply upper permutations in forward order
for j in (ihi+1):n
for j in (ihi + 1):n
LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X)
end
end
Expand Down
14 changes: 9 additions & 5 deletions test/rulesets/LinearAlgebra/matfun.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
@testset "matrix functions" begin
@testset "LinearAlgebra.exp!(A::Matrix) frule" begin
n = 10
@testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0)
@testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64),
nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved

A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n)
# choose normalization to hit specific branch
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
A *= nrm / opnorm(A, 1)
Expand All @@ -16,21 +18,23 @@

@testset "exp(A::Matrix) rrule" begin
n = 10
@testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0)
@testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64),
nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved

A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n)
ΔY = randn(ComplexF64, n, n)
# choose normalization to hit specific branch
A *= nrm / opnorm(A, 1)
# rrule is not inferrable, but pullback should be
rrule_test(exp, ΔY, (A, ΔA); check_inferred = false)
rrule_test(exp, ΔY, (A, ΔA); check_inferred=false)
Y, back = rrule(exp, A)
@inferred back(ΔY)
end
@testset "hermitian A" begin
A, ΔA = Matrix(Hermitian(randn(ComplexF64, n, n))), randn(ComplexF64, n, n)
ΔY = randn(ComplexF64, n, n)
rrule_test(exp, Matrix(Hermitian(ΔY)), (A, ΔA); check_inferred = false)
rrule_test(exp, ΔY, (A, ΔA); check_inferred = false)
rrule_test(exp, Matrix(Hermitian(ΔY)), (A, ΔA); check_inferred=false)
rrule_test(exp, ΔY, (A, ΔA); check_inferred=false)
end
end
end