From 6e8358cd92de0b163dbcedeed4febec9fd93e1cd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 02:44:56 -0800 Subject: [PATCH 01/33] Add matfun.jl file --- src/ChainRules.jl | 1 + src/rulesets/LinearAlgebra/matfun.jl | 1 + 2 files changed, 2 insertions(+) create mode 100644 src/rulesets/LinearAlgebra/matfun.jl diff --git a/src/ChainRules.jl b/src/ChainRules.jl index a631e5845..a8684c8a6 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -44,6 +44,7 @@ include("rulesets/LinearAlgebra/utils.jl") include("rulesets/LinearAlgebra/blas.jl") include("rulesets/LinearAlgebra/dense.jl") include("rulesets/LinearAlgebra/norm.jl") +include("rulesets/LinearAlgebra/matfun.jl") include("rulesets/LinearAlgebra/structured.jl") include("rulesets/LinearAlgebra/symmetric.jl") include("rulesets/LinearAlgebra/factorization.jl") diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl new file mode 100644 index 000000000..f5b6c0561 --- /dev/null +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -0,0 +1 @@ +# matrix functions of dense matrices From f987f3599773ea97c04734370065bb99d1bc55e4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 02:45:18 -0800 Subject: [PATCH 02/33] Add matfun docstrings --- src/rulesets/LinearAlgebra/matfun.jl | 33 +++++++++++++++++++++++++ src/rulesets/LinearAlgebra/symmetric.jl | 6 ----- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index f5b6c0561..7189e1982 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -1 +1,34 @@ # matrix functions of dense matrices + +""" + _matfun(f, A) -> (Y, intermediates) + +Compute the matrix function `Y=f(A)` for matrix `A`. +The function returns a tuple containing the result and a tuple of intermediates to be +reused by `_matfun_frechet` to compute the Fréchet derivative. +Note that any function `f` used with this **must** have a `frule` defined on it. +""" +_matfun + +""" + _matfun!(f, A) -> (Y, intermediates) + +Similar to [`_matfun`](@ref), but where `A` may be overwritten. +""" +_matfun! + +""" + _matfun_frechet(f, A, Y, ΔA, intermediates) + +Compute the Fréchet derivative of the matrix function `Y=f(A)`, where the Fréchet derivative +of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun`. +""" +_matfun_frechet + +""" + _matfun_frechet!(f, A, Y, ΔA, intermediates) + +Similar to `_matfun_frechet!`, but where `ΔA` may be overwritten. +""" +_matfun_frechet! + diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index a1985ec40..a9d6d3e2b 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -392,12 +392,6 @@ function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm) end # Computes ∂Y = U * (P .* (U' * ΔA * U)) * U' with fewer allocations -""" - _matfun_frechet(f, A::RealHermSymComplexHerm, Y, ΔA, intermediates) - -Compute the Fréchet derivative of the matrix function `Y=f(A)`, where the Fréchet derivative -of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun`. -""" function _matfun_frechet(f, A::LinearAlgebra.RealHermSymComplexHerm, Y, ΔA, (λ, U, fλ, df_dλ)) # We will overwrite tmp matrix several times to hold different values tmp = mul!(similar(U, Base.promote_eltype(U, ΔA)), ΔA, U) From e6b92c9788fa8e2eaf3194cf6ffd17fff8d8510b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 02:45:51 -0800 Subject: [PATCH 03/33] Add exp matrix function --- src/rulesets/LinearAlgebra/matfun.jl | 196 +++++++++++++++++++++++++++ 1 file changed, 196 insertions(+) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 7189e1982..5852dec17 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -32,3 +32,199 @@ Similar to `_matfun_frechet!`, but where `ΔA` may be overwritten. """ _matfun_frechet! +function frule((_, ΔA), ::typeof(LinearAlgebra.exp!), A::StridedMatrix{<:BlasFloat}) + if ishermitian(A) + hermX, ∂hermX = frule((Zero(), ΔA), exp, Hermitian(A)) + X = LinearAlgebra.copytri!(parent(hermX), 'U', true) + if ∂hermX isa LinearAlgebra.RealHermSymComplexHerm + ∂X = LinearAlgebra.copytri!(parent(∂hermX), 'U', true) + else + ∂X = ∂hermX + end + else + X, intermediates = _matfun!(exp, A) + ∂X = _matfun_frechet!(exp, A, X, ΔA, intermediates) + end + return X, ∂X +end + +function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat}) + # TODO: try to make this more type-stable + if ishermitian(A0) + # call _matfun instead of the rrule to avoid hermitrizing ∂A in the pullback + hermA = Hermitian(A0) + hermX, intermediates = _matfun(exp, hermA) + function exp_pullback_hermitian(ΔX) + ∂hermA = _matfun_frechet(exp, hermA, hermX, ΔX, intermediates) + ∂A = ∂hermA isa LinearAlgebra.RealHermSymComplexHerm ? parent(∂hermA) : ∂hermA + return NO_FIELDS, ∂A + end + return LinearAlgebra.copytri!(parent(hermX), 'U', true), exp_pullback_hermitian + else + A = copy(A0) + X, intermediates = _matfun!(exp, A) + function exp_pullback(ΔX) + ΔX′ = copy(adjoint(ΔX)) + ∂A′ = _matfun_frechet!(exp, A, X, ΔX′, intermediates) + ∂A = copy(adjoint(∂A′)) + return NO_FIELDS, ∂A + end + return X, exp_pullback + end +end + +## Destructive matrix exponential using algorithm from Higham, 2008, +## "Functions of Matrices: Theory and Computation", SIAM +## Adapted from LinearAlgebra.exp! with return of intermediates +function _matfun!(::typeof(exp), A::StridedMatrix{T}) where T<:BlasFloat + n = LinearAlgebra.checksquare(A) + ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A + nA = opnorm(A, 1) + Inn = Matrix{T}(I, n, n) + ## For sufficiently small nA, use lower order Padé-Approximations + if (nA <= 2.1) + if nA > 0.95 + C = T[17643225600.,8821612800.,2075673600.,302702400., + 30270240., 2162160., 110880., 3960., + 90., 1.] + elseif nA > 0.25 + C = T[17297280.,8648640.,1995840.,277200., + 25200., 1512., 56., 1.] + elseif nA > 0.015 + C = T[30240.,15120.,3360., + 420., 30., 1.] + else + C = T[120.,60.,12.,1.] + end + si = 0 + else + C = T[64764752532480000.,32382376266240000.,7771770303897600., + 1187353796428800., 129060195264000., 10559470521600., + 670442572800., 33522128640., 1323241920., + 40840800., 960960., 16380., + 182., 1.] + s = log2(nA/5.4) # power of 2 later reversed by squaring + si = ceil(Int,s) + end + + if si > 0 + A ./= convert(T,2^si) + end + + A2 = A * A + P = copy(Inn) + W = C[2] * P + V = C[1] * P + Apows = typeof(P)[] + for k in 1:(div(size(C, 1), 2) - 1) + k2 = 2 * k + P *= A2 + push!(Apows, P) + W += C[k2 + 2] * P + V += C[k2 + 1] * P + end + pop!(Apows) + U = A * W + X = V + U + F = lu!(V-U) # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization + ldiv!(F, X) + Xpows = typeof(X)[X] + if si > 0 # squaring to reverse dividing by power of 2 + for t=1:si + X *= X + push!(Xpows, X) + end + pop!(Xpows) + end + + # Undo the balancing + for j = ilo:ihi + scj = scale[j] + for i = 1:n + X[j,i] *= scj + end + for i = 1:n + X[i,j] /= scj + end + end + + if ilo > 1 # apply lower permutations in reverse order + for j in (ilo-1):-1:1; LinearAlgebra.rcswap!(j, Int(scale[j]), X) end + end + if ihi < n # apply upper permutations in forward order + for j in (ihi+1):n; LinearAlgebra.rcswap!(j, Int(scale[j]), X) end + end + return X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows) +end + +function _matfun_frechet!( + ::typeof(exp), + A::StridedMatrix{T}, + X, + ΔA, + (ilo, ihi, scale, C, si, Apows, W, F, Xpows), +) where {T<:BlasFloat} + n = LinearAlgebra.checksquare(A) + for j = ilo:ihi + scj = scale[j] + for i = 1:n + ΔA[j,i] /= scj + end + for i = 1:n + ΔA[i,j] *= scj + end + end + + if si > 0 + ΔA ./= convert(T, 2^si) + end + + ∂A2 = mul!(A * ΔA, ΔA, A, true, true) + A2 = first(Apows) + # we will repeatedly overwrite ∂temp and ∂P below + ∂temp = Matrix{eltype(∂A2)}(undef, n, n) + ∂P = copy(∂A2) + ∂W = C[4] * ∂P + ∂V = C[3] * ∂P + for k in 2:length(Apows) + k2 = 2 * k + P = Apows[k - 1] + ∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P + axpy!(C[k2 + 2], ∂P, ∂W) + axpy!(C[k2 + 1], ∂P, ∂V) + end + ∂U, ∂temp = mul!(mul!(∂temp, A, ∂W), ΔA, W, true, true), ∂W + ∂temp .= ∂U .- ∂V + ∂X = add!!(∂U, ∂V) + mul!(∂X, ∂temp, first(Xpows), true, true) + ldiv!(F, ∂X) + + if si > 0 + for t = eachindex(Xpows) + X = Xpows[t] + ∂X, ∂temp = mul!(mul!(∂temp, X, ∂X), ∂X, X, true, true), ∂X + end + end + + for j = ilo:ihi + scj = scale[j] + for i = 1:n + ∂X[j,i] *= scj + end + for i = 1:n + ∂X[i,j] /= scj + end + end + + if ilo > 1 # apply lower permutations in reverse order + for j in (ilo-1):-1:1 + LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X) + end + end + if ihi < n # apply upper permutations in forward order + for j in (ihi+1):n + LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X) + end + end + return ∂X +end From 7624ee5f74f0b0544ebbeb42c2bf6d466ab2044e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 03:02:45 -0800 Subject: [PATCH 04/33] At least store one intermediate --- src/rulesets/LinearAlgebra/matfun.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 5852dec17..2229a0db4 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -123,7 +123,6 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where T<:BlasFloat W += C[k2 + 2] * P V += C[k2 + 1] * P end - pop!(Apows) U = A * W X = V + U F = lu!(V-U) # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization @@ -134,7 +133,6 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where T<:BlasFloat X *= X push!(Xpows, X) end - pop!(Xpows) end # Undo the balancing @@ -186,7 +184,7 @@ function _matfun_frechet!( ∂P = copy(∂A2) ∂W = C[4] * ∂P ∂V = C[3] * ∂P - for k in 2:length(Apows) + for k in 2:(length(Apows)-1) k2 = 2 * k P = Apows[k - 1] ∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P @@ -200,7 +198,7 @@ function _matfun_frechet!( ldiv!(F, ∂X) if si > 0 - for t = eachindex(Xpows) + for t = 1:(length(Xpows)-1) X = Xpows[t] ∂X, ∂temp = mul!(mul!(∂temp, X, ∂X), ∂X, X, true, true), ∂X end From a7792a5cb51a2693853aa4ec5942f046b8403de3 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 03:03:16 -0800 Subject: [PATCH 05/33] Test exp! --- test/rulesets/LinearAlgebra/matfun.jl | 17 +++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 18 insertions(+) create mode 100644 test/rulesets/LinearAlgebra/matfun.jl diff --git a/test/rulesets/LinearAlgebra/matfun.jl b/test/rulesets/LinearAlgebra/matfun.jl new file mode 100644 index 000000000..127ac49ff --- /dev/null +++ b/test/rulesets/LinearAlgebra/matfun.jl @@ -0,0 +1,17 @@ +@testset "matrix functions" begin + @testset "LinearAlgebra.exp!(A) frule" begin + n = 10 + @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) + A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) + # choose normalization to hit specific branch + A *= nrm / opnorm(A, 1) + frule_test(LinearAlgebra.exp!, (A, ΔA)) + end + @testset "hermitian A" begin + A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) + A = Matrix(Hermitian(A)) + frule_test(LinearAlgebra.exp!, (A, Matrix(Hermitian(ΔA)))) + frule_test(LinearAlgebra.exp!, (A, ΔA)) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 7362d68a1..ea6eb5d8b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,6 +43,7 @@ println("Testing ChainRules.jl") @testset "LinearAlgebra" begin include_test("rulesets/LinearAlgebra/dense.jl") include_test("rulesets/LinearAlgebra/norm.jl") + include_test("rulesets/LinearAlgebra/matfun.jl") include_test("rulesets/LinearAlgebra/structured.jl") include_test("rulesets/LinearAlgebra/symmetric.jl") include_test("rulesets/LinearAlgebra/factorization.jl") From 6d6b4cb142099ec45d796cd1d73393f475cb76bf Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 03:13:42 -0800 Subject: [PATCH 06/33] Make pullback type-inferrable --- src/rulesets/LinearAlgebra/matfun.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 2229a0db4..acba6cdb7 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -53,11 +53,11 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat}) if ishermitian(A0) # call _matfun instead of the rrule to avoid hermitrizing ∂A in the pullback hermA = Hermitian(A0) - hermX, intermediates = _matfun(exp, hermA) + hermX, hermX_intermediates = _matfun(exp, hermA) function exp_pullback_hermitian(ΔX) - ∂hermA = _matfun_frechet(exp, hermA, hermX, ΔX, intermediates) - ∂A = ∂hermA isa LinearAlgebra.RealHermSymComplexHerm ? parent(∂hermA) : ∂hermA - return NO_FIELDS, ∂A + ∂hermA = _matfun_frechet(exp, hermA, hermX, ΔX, hermX_intermediates) + ∂hermA isa LinearAlgebra.RealHermSymComplexHerm || return NO_FIELDS, ∂hermA + return NO_FIELDS, parent(∂hermA) end return LinearAlgebra.copytri!(parent(hermX), 'U', true), exp_pullback_hermitian else From 3645d7540d7dd6eb92136bfbeae3e935b4e0381e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 03:14:07 -0800 Subject: [PATCH 07/33] Add clearer test label --- test/rulesets/LinearAlgebra/matfun.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/matfun.jl b/test/rulesets/LinearAlgebra/matfun.jl index 127ac49ff..44956cd5d 100644 --- a/test/rulesets/LinearAlgebra/matfun.jl +++ b/test/rulesets/LinearAlgebra/matfun.jl @@ -1,5 +1,5 @@ @testset "matrix functions" begin - @testset "LinearAlgebra.exp!(A) frule" begin + @testset "LinearAlgebra.exp!(A::Matrix) frule" begin n = 10 @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) From 937f2acc5e256302df3d5c371807c47c418615ac Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 03:14:21 -0800 Subject: [PATCH 08/33] Create as hermitian --- test/rulesets/LinearAlgebra/matfun.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/rulesets/LinearAlgebra/matfun.jl b/test/rulesets/LinearAlgebra/matfun.jl index 44956cd5d..aed89a551 100644 --- a/test/rulesets/LinearAlgebra/matfun.jl +++ b/test/rulesets/LinearAlgebra/matfun.jl @@ -8,8 +8,7 @@ frule_test(LinearAlgebra.exp!, (A, ΔA)) end @testset "hermitian A" begin - A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) - A = Matrix(Hermitian(A)) + A, ΔA = Matrix(Hermitian(randn(ComplexF64, n, n))), randn(ComplexF64, n, n) frule_test(LinearAlgebra.exp!, (A, Matrix(Hermitian(ΔA)))) frule_test(LinearAlgebra.exp!, (A, ΔA)) end From b48204c60552b48be11da3faa9ed0f92f64bda53 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 03:14:33 -0800 Subject: [PATCH 09/33] Test rrule --- test/rulesets/LinearAlgebra/matfun.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/rulesets/LinearAlgebra/matfun.jl b/test/rulesets/LinearAlgebra/matfun.jl index aed89a551..12ee2b9ef 100644 --- a/test/rulesets/LinearAlgebra/matfun.jl +++ b/test/rulesets/LinearAlgebra/matfun.jl @@ -13,4 +13,24 @@ frule_test(LinearAlgebra.exp!, (A, ΔA)) end end + + @testset "exp(A::Matrix) rrule" begin + n = 10 + @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) + A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) + ΔY = randn(ComplexF64, n, n) + # choose normalization to hit specific branch + A *= nrm / opnorm(A, 1) + # rrule is not inferrable, but pullback should be + rrule_test(exp, ΔY, (A, ΔA); check_inferred = false) + Y, back = rrule(exp, A) + @inferred back(ΔY) + end + @testset "hermitian A" begin + A, ΔA = Matrix(Hermitian(randn(ComplexF64, n, n))), randn(ComplexF64, n, n) + ΔY = randn(ComplexF64, n, n) + rrule_test(exp, Matrix(Hermitian(ΔY)), (A, ΔA); check_inferred = false) + rrule_test(exp, ΔY, (A, ΔA); check_inferred = false) + end + end end From 2c19bba68599c0c65da3a3ccc635209fd99094f9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 03:21:45 -0800 Subject: [PATCH 10/33] Add comment about relationship between pushforward and pullback --- src/rulesets/LinearAlgebra/matfun.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index acba6cdb7..cc873a3ec 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -1,5 +1,11 @@ # matrix functions of dense matrices +# NOTE: for matrix functions whose power series representation has real coefficients, +# the pullback and pushforward are related by an adjoint. +# Specifically, if the pushforward of f(A) is (f_*)_A(ΔA), then the pullback at Y=f(A) is +# (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))' +# So we reuse the code from the pushforward to implement the pullback. + """ _matfun(f, A) -> (Y, intermediates) From 58f60050c24b53299400b6b84c12d196aca38e74 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 03:22:36 -0800 Subject: [PATCH 11/33] Add header --- src/rulesets/LinearAlgebra/matfun.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index cc873a3ec..1840234d3 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -38,6 +38,10 @@ Similar to `_matfun_frechet!`, but where `ΔA` may be overwritten. """ _matfun_frechet! +##### +##### `exp`/`exp!` +##### + function frule((_, ΔA), ::typeof(LinearAlgebra.exp!), A::StridedMatrix{<:BlasFloat}) if ishermitian(A) hermX, ∂hermX = frule((Zero(), ΔA), exp, Hermitian(A)) From 6ee17595ac4ad9f28c248b61d83e9b2fa4bef203 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 03:26:25 -0800 Subject: [PATCH 12/33] Add reference to Frechet deriv paper --- src/rulesets/LinearAlgebra/matfun.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 1840234d3..179c476bf 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -165,6 +165,11 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where T<:BlasFloat return X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows) end +# Application of the chain rule to exp!, also Algorithm 6.4 from +# Al-Mohy, Awad H. and Higham, Nicholas J. (2009). +# Computing the Fréchet Derivative of the Matrix Exponential, with an application to +# Condition Number Estimation", SIAM. 30 (4). pp. 1639-1657. +# http://eprints.maths.manchester.ac.uk/id/eprint/1218 function _matfun_frechet!( ::typeof(exp), A::StridedMatrix{T}, From b1a29802e8ee65a81de79f17d93e6e8779e88c77 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 11:31:30 -0800 Subject: [PATCH 13/33] Run JuliaFormatter --- src/rulesets/LinearAlgebra/matfun.jl | 118 +++++++++++++++----------- test/rulesets/LinearAlgebra/matfun.jl | 14 +-- 2 files changed, 77 insertions(+), 55 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 179c476bf..3fadf9b9f 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -86,45 +86,63 @@ end ## Destructive matrix exponential using algorithm from Higham, 2008, ## "Functions of Matrices: Theory and Computation", SIAM ## Adapted from LinearAlgebra.exp! with return of intermediates -function _matfun!(::typeof(exp), A::StridedMatrix{T}) where T<:BlasFloat +function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat} n = LinearAlgebra.checksquare(A) ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A - nA = opnorm(A, 1) - Inn = Matrix{T}(I, n, n) + nA = opnorm(A, 1) + Inn = Matrix{T}(I, n, n) ## For sufficiently small nA, use lower order Padé-Approximations if (nA <= 2.1) if nA > 0.95 - C = T[17643225600.,8821612800.,2075673600.,302702400., - 30270240., 2162160., 110880., 3960., - 90., 1.] + C = T[ + 17643225600.0, + 8821612800.0, + 2075673600.0, + 302702400.0, + 30270240.0, + 2162160.0, + 110880.0, + 3960.0, + 90.0, + 1.0, + ] elseif nA > 0.25 - C = T[17297280.,8648640.,1995840.,277200., - 25200., 1512., 56., 1.] + C = T[17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0] elseif nA > 0.015 - C = T[30240.,15120.,3360., - 420., 30., 1.] + C = T[30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0] else - C = T[120.,60.,12.,1.] + C = T[120.0, 60.0, 12.0, 1.0] end si = 0 else - C = T[64764752532480000.,32382376266240000.,7771770303897600., - 1187353796428800., 129060195264000., 10559470521600., - 670442572800., 33522128640., 1323241920., - 40840800., 960960., 16380., - 182., 1.] - s = log2(nA/5.4) # power of 2 later reversed by squaring - si = ceil(Int,s) + C = T[ + 64764752532480000.0, + 32382376266240000.0, + 7771770303897600.0, + 1187353796428800.0, + 129060195264000.0, + 10559470521600.0, + 670442572800.0, + 33522128640.0, + 1323241920.0, + 40840800.0, + 960960.0, + 16380.0, + 182.0, + 1.0, + ] + s = log2(nA / 5.4) # power of 2 later reversed by squaring + si = ceil(Int, s) end if si > 0 - A ./= convert(T,2^si) + A ./= convert(T, 2^si) end A2 = A * A - P = copy(Inn) - W = C[2] * P - V = C[1] * P + P = copy(Inn) + W = C[2] * P + V = C[1] * P Apows = typeof(P)[] for k in 1:(div(size(C, 1), 2) - 1) k2 = 2 * k @@ -135,32 +153,36 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where T<:BlasFloat end U = A * W X = V + U - F = lu!(V-U) # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization + F = lu!(V - U) # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization ldiv!(F, X) Xpows = typeof(X)[X] if si > 0 # squaring to reverse dividing by power of 2 - for t=1:si + for t in 1:si X *= X push!(Xpows, X) end end # Undo the balancing - for j = ilo:ihi + for j in ilo:ihi scj = scale[j] - for i = 1:n - X[j,i] *= scj + for i in 1:n + X[j, i] *= scj end - for i = 1:n - X[i,j] /= scj + for i in 1:n + X[i, j] /= scj end end if ilo > 1 # apply lower permutations in reverse order - for j in (ilo-1):-1:1; LinearAlgebra.rcswap!(j, Int(scale[j]), X) end + for j in (ilo - 1):-1:1 + LinearAlgebra.rcswap!(j, Int(scale[j]), X) + end end if ihi < n # apply upper permutations in forward order - for j in (ihi+1):n; LinearAlgebra.rcswap!(j, Int(scale[j]), X) end + for j in (ihi + 1):n + LinearAlgebra.rcswap!(j, Int(scale[j]), X) + end end return X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows) end @@ -171,20 +193,16 @@ end # Condition Number Estimation", SIAM. 30 (4). pp. 1639-1657. # http://eprints.maths.manchester.ac.uk/id/eprint/1218 function _matfun_frechet!( - ::typeof(exp), - A::StridedMatrix{T}, - X, - ΔA, - (ilo, ihi, scale, C, si, Apows, W, F, Xpows), + ::typeof(exp), A::StridedMatrix{T}, X, ΔA, (ilo, ihi, scale, C, si, Apows, W, F, Xpows) ) where {T<:BlasFloat} n = LinearAlgebra.checksquare(A) - for j = ilo:ihi + for j in ilo:ihi scj = scale[j] - for i = 1:n - ΔA[j,i] /= scj + for i in 1:n + ΔA[j, i] /= scj end - for i = 1:n - ΔA[i,j] *= scj + for i in 1:n + ΔA[i, j] *= scj end end @@ -199,7 +217,7 @@ function _matfun_frechet!( ∂P = copy(∂A2) ∂W = C[4] * ∂P ∂V = C[3] * ∂P - for k in 2:(length(Apows)-1) + for k in 2:(length(Apows) - 1) k2 = 2 * k P = Apows[k - 1] ∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P @@ -213,29 +231,29 @@ function _matfun_frechet!( ldiv!(F, ∂X) if si > 0 - for t = 1:(length(Xpows)-1) + for t in 1:(length(Xpows) - 1) X = Xpows[t] ∂X, ∂temp = mul!(mul!(∂temp, X, ∂X), ∂X, X, true, true), ∂X end end - for j = ilo:ihi + for j in ilo:ihi scj = scale[j] - for i = 1:n - ∂X[j,i] *= scj + for i in 1:n + ∂X[j, i] *= scj end - for i = 1:n - ∂X[i,j] /= scj + for i in 1:n + ∂X[i, j] /= scj end end if ilo > 1 # apply lower permutations in reverse order - for j in (ilo-1):-1:1 + for j in (ilo - 1):-1:1 LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X) end end if ihi < n # apply upper permutations in forward order - for j in (ihi+1):n + for j in (ihi + 1):n LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X) end end diff --git a/test/rulesets/LinearAlgebra/matfun.jl b/test/rulesets/LinearAlgebra/matfun.jl index 12ee2b9ef..ac6fe03f7 100644 --- a/test/rulesets/LinearAlgebra/matfun.jl +++ b/test/rulesets/LinearAlgebra/matfun.jl @@ -1,7 +1,9 @@ @testset "matrix functions" begin @testset "LinearAlgebra.exp!(A::Matrix) frule" begin n = 10 - @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) + @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), +nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) + A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) # choose normalization to hit specific branch A *= nrm / opnorm(A, 1) @@ -16,21 +18,23 @@ @testset "exp(A::Matrix) rrule" begin n = 10 - @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) + @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), +nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) + A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) ΔY = randn(ComplexF64, n, n) # choose normalization to hit specific branch A *= nrm / opnorm(A, 1) # rrule is not inferrable, but pullback should be - rrule_test(exp, ΔY, (A, ΔA); check_inferred = false) + rrule_test(exp, ΔY, (A, ΔA); check_inferred=false) Y, back = rrule(exp, A) @inferred back(ΔY) end @testset "hermitian A" begin A, ΔA = Matrix(Hermitian(randn(ComplexF64, n, n))), randn(ComplexF64, n, n) ΔY = randn(ComplexF64, n, n) - rrule_test(exp, Matrix(Hermitian(ΔY)), (A, ΔA); check_inferred = false) - rrule_test(exp, ΔY, (A, ΔA); check_inferred = false) + rrule_test(exp, Matrix(Hermitian(ΔY)), (A, ΔA); check_inferred=false) + rrule_test(exp, ΔY, (A, ΔA); check_inferred=false) end end end From e860b3eff0ddc67882bd0ee97a3e1358153cbf71 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 11:36:37 -0800 Subject: [PATCH 14/33] Reduce comment spacing from code --- src/rulesets/LinearAlgebra/matfun.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 3fadf9b9f..b5ae252cf 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -88,7 +88,7 @@ end ## Adapted from LinearAlgebra.exp! with return of intermediates function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat} n = LinearAlgebra.checksquare(A) - ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A + ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A nA = opnorm(A, 1) Inn = Matrix{T}(I, n, n) ## For sufficiently small nA, use lower order Padé-Approximations @@ -131,7 +131,7 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat} 182.0, 1.0, ] - s = log2(nA / 5.4) # power of 2 later reversed by squaring + s = log2(nA / 5.4) # power of 2 later reversed by squaring si = ceil(Int, s) end @@ -153,10 +153,10 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat} end U = A * W X = V + U - F = lu!(V - U) # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization + F = lu!(V - U) # NOTE: use lu! instead of LAPACK.gesv! so we can reuse factorization ldiv!(F, X) Xpows = typeof(X)[X] - if si > 0 # squaring to reverse dividing by power of 2 + if si > 0 # squaring to reverse dividing by power of 2 for t in 1:si X *= X push!(Xpows, X) @@ -174,12 +174,12 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat} end end - if ilo > 1 # apply lower permutations in reverse order + if ilo > 1 # apply lower permutations in reverse order for j in (ilo - 1):-1:1 LinearAlgebra.rcswap!(j, Int(scale[j]), X) end end - if ihi < n # apply upper permutations in forward order + if ihi < n # apply upper permutations in forward order for j in (ihi + 1):n LinearAlgebra.rcswap!(j, Int(scale[j]), X) end @@ -247,12 +247,12 @@ function _matfun_frechet!( end end - if ilo > 1 # apply lower permutations in reverse order + if ilo > 1 # apply lower permutations in reverse order for j in (ilo - 1):-1:1 LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X) end end - if ihi < n # apply upper permutations in forward order + if ihi < n # apply upper permutations in forward order for j in (ihi + 1):n LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X) end From 8f665ac8c69a9aa56f06c3ef011d9327e99c770f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 11:37:27 -0800 Subject: [PATCH 15/33] Update src/rulesets/LinearAlgebra/matfun.jl --- src/rulesets/LinearAlgebra/matfun.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index b5ae252cf..dc727e37e 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -34,7 +34,7 @@ _matfun_frechet """ _matfun_frechet!(f, A, Y, ΔA, intermediates) -Similar to `_matfun_frechet!`, but where `ΔA` may be overwritten. +Similar to [`_matfun_frechet`](@ref), but where `ΔA` may be overwritten. """ _matfun_frechet! From 9e565ae2c7ebc1458e3f0145e1476fba49225a2c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 13:04:04 -0800 Subject: [PATCH 16/33] Correctly handle balancing --- src/rulesets/LinearAlgebra/matfun.jl | 79 ++++++++++++++++------------ 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index b5ae252cf..b79951cd8 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -163,27 +163,7 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat} end end - # Undo the balancing - for j in ilo:ihi - scj = scale[j] - for i in 1:n - X[j, i] *= scj - end - for i in 1:n - X[i, j] /= scj - end - end - - if ilo > 1 # apply lower permutations in reverse order - for j in (ilo - 1):-1:1 - LinearAlgebra.rcswap!(j, Int(scale[j]), X) - end - end - if ihi < n # apply upper permutations in forward order - for j in (ihi + 1):n - LinearAlgebra.rcswap!(j, Int(scale[j]), X) - end - end + _unbalance!(X, ilo, ihi, scale, n) return X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows) end @@ -196,15 +176,7 @@ function _matfun_frechet!( ::typeof(exp), A::StridedMatrix{T}, X, ΔA, (ilo, ihi, scale, C, si, Apows, W, F, Xpows) ) where {T<:BlasFloat} n = LinearAlgebra.checksquare(A) - for j in ilo:ihi - scj = scale[j] - for i in 1:n - ΔA[j, i] /= scj - end - for i in 1:n - ΔA[i, j] *= scj - end - end + _balance!(ΔA, ilo, ihi, scale, n) if si > 0 ΔA ./= convert(T, 2^si) @@ -237,25 +209,62 @@ function _matfun_frechet!( end end + _unbalance!(∂X, ilo, ihi, scale, n) + return ∂X +end + +##### +##### utils +##### + +# Given (ilo, ihi, iscale) returned by LAPACK.gebal!('B', A), apply same balancing to X +function _balance!(X, ilo, ihi, scale, n) + n = size(X, 1) + if ihi < n + for j in (ihi + 1):n + LinearAlgebra.rcswap!(j, Int(scale[j]), X) + end + end + if ilo > 1 + for j in (ilo - 1):-1:1 + LinearAlgebra.rcswap!(j, Int(scale[j]), X) + end + end + + # Undo the balancing for j in ilo:ihi scj = scale[j] for i in 1:n - ∂X[j, i] *= scj + X[j, i] /= scj end for i in 1:n - ∂X[i, j] /= scj + X[i, j] *= scj + end + end + return X +end + +# Reverse of _balance! +function _unbalance!(X, ilo, ihi, scale, n) + for j in ilo:ihi + scj = scale[j] + for i in 1:n + X[j, i] *= scj + end + for i in 1:n + X[i, j] /= scj end end if ilo > 1 # apply lower permutations in reverse order for j in (ilo - 1):-1:1 - LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X) + LinearAlgebra.rcswap!(j, Int(scale[j]), X) end end if ihi < n # apply upper permutations in forward order for j in (ihi + 1):n - LinearAlgebra.rcswap!(j, Int(scale[j]), ∂X) + LinearAlgebra.rcswap!(j, Int(scale[j]), X) end end - return ∂X + return X end From 71134fdadfa7b1d5562fa4d7b155fccfe3fe9696 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 13:04:14 -0800 Subject: [PATCH 17/33] Test imbalanced matrix A --- test/rulesets/LinearAlgebra/matfun.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/rulesets/LinearAlgebra/matfun.jl b/test/rulesets/LinearAlgebra/matfun.jl index ac6fe03f7..a7715dc68 100644 --- a/test/rulesets/LinearAlgebra/matfun.jl +++ b/test/rulesets/LinearAlgebra/matfun.jl @@ -9,6 +9,11 @@ nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) A *= nrm / opnorm(A, 1) frule_test(LinearAlgebra.exp!, (A, ΔA)) end + @testset "imbalanced A" begin + A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0] + ΔA = rand_tangent(A) + frule_test(LinearAlgebra.exp!, (A, ΔA)) + end @testset "hermitian A" begin A, ΔA = Matrix(Hermitian(randn(ComplexF64, n, n))), randn(ComplexF64, n, n) frule_test(LinearAlgebra.exp!, (A, Matrix(Hermitian(ΔA)))) @@ -30,6 +35,12 @@ nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) Y, back = rrule(exp, A) @inferred back(ΔY) end + @testset "imbalanced A" begin + A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0] + ΔA = rand_tangent(A) + ΔY = rand_tangent(exp(A)) + rrule_test(exp, ΔY, (A, ΔA); check_inferred=false) + end @testset "hermitian A" begin A, ΔA = Matrix(Hermitian(randn(ComplexF64, n, n))), randn(ComplexF64, n, n) ΔY = randn(ComplexF64, n, n) From bd48565a341b8c6aa652769f72708272ea3d7352 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 18 Jan 2021 13:04:37 -0800 Subject: [PATCH 18/33] Increment version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 988e5ec32..f62fe3d8c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.48" +version = "0.7.49" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From dc1b1ab47cdbd460e4137efc89614ac2c1adc3ec Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 11:49:12 -0800 Subject: [PATCH 19/33] Apply suggestions from code review Co-authored-by: Lyndon White --- src/rulesets/LinearAlgebra/matfun.jl | 2 ++ test/rulesets/LinearAlgebra/matfun.jl | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index c80bd1773..0e6302b29 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -1,4 +1,5 @@ # matrix functions of dense matrices +# https://en.wikipedia.org/wiki/Matrix_function # NOTE: for matrix functions whose power series representation has real coefficients, # the pullback and pushforward are related by an adjoint. @@ -86,6 +87,7 @@ end ## Destructive matrix exponential using algorithm from Higham, 2008, ## "Functions of Matrices: Theory and Computation", SIAM ## Adapted from LinearAlgebra.exp! with return of intermediates +## https://github.com/JuliaLang/julia/blob/f613b551009a7a9dbe46235929099f2ecd28bed1/stdlib/LinearAlgebra/src/dense.jl#L583-L666 function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat} n = LinearAlgebra.checksquare(A) ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A diff --git a/test/rulesets/LinearAlgebra/matfun.jl b/test/rulesets/LinearAlgebra/matfun.jl index a7715dc68..2e98d5fd4 100644 --- a/test/rulesets/LinearAlgebra/matfun.jl +++ b/test/rulesets/LinearAlgebra/matfun.jl @@ -2,7 +2,7 @@ @testset "LinearAlgebra.exp!(A::Matrix) frule" begin n = 10 @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), -nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) + nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) # choose normalization to hit specific branch @@ -15,7 +15,8 @@ nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) frule_test(LinearAlgebra.exp!, (A, ΔA)) end @testset "hermitian A" begin - A, ΔA = Matrix(Hermitian(randn(ComplexF64, n, n))), randn(ComplexF64, n, n) + A = Matrix(Hermitian(randn(ComplexF64, n, n))) + ΔA = randn(ComplexF64, n, n) frule_test(LinearAlgebra.exp!, (A, Matrix(Hermitian(ΔA)))) frule_test(LinearAlgebra.exp!, (A, ΔA)) end @@ -24,7 +25,7 @@ nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) @testset "exp(A::Matrix) rrule" begin n = 10 @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), -nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) + nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) ΔY = randn(ComplexF64, n, n) From 57aea176ccc4aee151930a683391c86a7a2bfbd1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 16:15:27 -0800 Subject: [PATCH 20/33] Change signature of _matfun_frechet --- src/rulesets/LinearAlgebra/matfun.jl | 8 ++++---- src/rulesets/LinearAlgebra/symmetric.jl | 11 +++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 0e6302b29..1ae6f968e 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -25,7 +25,7 @@ Similar to [`_matfun`](@ref), but where `A` may be overwritten. _matfun! """ - _matfun_frechet(f, A, Y, ΔA, intermediates) + _matfun_frechet(f, ΔA, A, Y, intermediates) Compute the Fréchet derivative of the matrix function `Y=f(A)`, where the Fréchet derivative of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun`. @@ -33,7 +33,7 @@ of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun _matfun_frechet """ - _matfun_frechet!(f, A, Y, ΔA, intermediates) + _matfun_frechet!(f, ΔA, A, Y, intermediates) Similar to [`_matfun_frechet`](@ref), but where `ΔA` may be overwritten. """ @@ -54,7 +54,7 @@ function frule((_, ΔA), ::typeof(LinearAlgebra.exp!), A::StridedMatrix{<:BlasFl end else X, intermediates = _matfun!(exp, A) - ∂X = _matfun_frechet!(exp, A, X, ΔA, intermediates) + ∂X = _matfun_frechet!(exp, ΔA, A, X, intermediates) end return X, ∂X end @@ -175,7 +175,7 @@ end # Condition Number Estimation", SIAM. 30 (4). pp. 1639-1657. # http://eprints.maths.manchester.ac.uk/id/eprint/1218 function _matfun_frechet!( - ::typeof(exp), A::StridedMatrix{T}, X, ΔA, (ilo, ihi, scale, C, si, Apows, W, F, Xpows) + ::typeof(exp), ΔA, A::StridedMatrix{T}, X, (ilo, ihi, scale, C, si, Apows, W, F, Xpows) ) where {T<:BlasFloat} n = LinearAlgebra.checksquare(A) _balance!(ΔA, ilo, ihi, scale, n) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index a9d6d3e2b..7def77cac 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -292,7 +292,7 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a @eval begin function frule((_, ΔA), ::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm) Y, intermediates = _matfun($func, A) - Ȳ = _matfun_frechet($func, A, Y, ΔA, intermediates) + Ȳ = _matfun_frechet($func, ΔA, A, Y, intermediates) # If ΔA was hermitian, then ∂Y has the same structure as Y ∂Y = if ishermitian(ΔA) && (isa(Y, Symmetric) || isa(Y, Hermitian)) _symhermlike!(Ȳ, Y) @@ -308,8 +308,7 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a # for Hermitian Y, we don't need to realify the diagonal of ΔY, since the # effect is the same as applying _hermitrizelike! at the end ∂Y = eltype(Y) <: Real ? real(ΔY) : ΔY - # for matrix functions, the pullback is related to the pushforward by an adjoint - Ā = _matfun_frechet($func, A, Y, ∂Y', intermediates) + Ā = _matfun_frechet_adjoint($func, ∂Y, A, Y, intermediates) # the cotangent of Hermitian A should be Hermitian ∂A = _hermitrizelike!(Ā, A) return NO_FIELDS, ∂A @@ -344,9 +343,9 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm) ΔsinA, ΔcosA = real(ΔsinA), real(ΔcosA) end if ΔcosA isa AbstractZero - Ā = _matfun_frechet(sin, A, sinA, ΔsinA, (λ, U, sinλ, cosλ)) + Ā = _matfun_frechet_adjoint(sin, ΔsinA, A, sinA, (λ, U, sinλ, cosλ)) elseif ΔsinA isa AbstractZero - Ā = _matfun_frechet(cos, A, cosA, ΔcosA, (λ, U, cosλ, -sinλ)) + Ā = _matfun_frechet_adjoint(cos, ΔcosA, A, cosA, (λ, U, cosλ, -sinλ)) else # we will overwrite tmp with various temporary values during this computation tmp = mul!(similar(U, Base.promote_eltype(U, ΔsinA, ΔcosA)), ΔsinA, U) @@ -392,7 +391,7 @@ function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm) end # Computes ∂Y = U * (P .* (U' * ΔA * U)) * U' with fewer allocations -function _matfun_frechet(f, A::LinearAlgebra.RealHermSymComplexHerm, Y, ΔA, (λ, U, fλ, df_dλ)) +function _matfun_frechet(f, ΔA, A::LinearAlgebra.RealHermSymComplexHerm, Y, (λ, U, fλ, df_dλ)) # We will overwrite tmp matrix several times to hold different values tmp = mul!(similar(U, Base.promote_eltype(U, ΔA)), ΔA, U) ∂Λ = mul!(similar(tmp), U', tmp) From e2e66052b4dc94b56c43f56c1ddf0f9859192a1f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 16:16:19 -0800 Subject: [PATCH 21/33] Give math for Frechet derivative --- src/rulesets/LinearAlgebra/matfun.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 1ae6f968e..d8f45e1bd 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -29,6 +29,10 @@ _matfun! Compute the Fréchet derivative of the matrix function `Y=f(A)`, where the Fréchet derivative of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun`. +The Fréchet derivative is defined as +```math +L(ΔA) = \\lim_{t → 0} \\frac{f(A + t ΔA) - f(A)}{t} +``` """ _matfun_frechet From 976af09d23d084caf5eac5123bce7f8ff0894239 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 17:16:45 -0800 Subject: [PATCH 22/33] Change Frechet notation --- src/rulesets/LinearAlgebra/matfun.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index d8f45e1bd..68e60c477 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -25,21 +25,22 @@ Similar to [`_matfun`](@ref), but where `A` may be overwritten. _matfun! """ - _matfun_frechet(f, ΔA, A, Y, intermediates) + _matfun_frechet(f, E, A, Y, intermediates) -Compute the Fréchet derivative of the matrix function `Y=f(A)`, where the Fréchet derivative -of `A` is `ΔA`, and `intermediates` is the second argument returned by `_matfun`. -The Fréchet derivative is defined as +Compute the Fréchet derivative of the matrix function ``Y = f(A)`` at ``A`` in the direction +of ``E``, where `intermediates` is the second argument returned by [`_matfun`](@ref). + +The Fréchet derivative is the unique linear map ``L_f \\colon E → L_f(A, E)``, such that ```math -L(ΔA) = \\lim_{t → 0} \\frac{f(A + t ΔA) - f(A)}{t} +L_f(A, E) = f(A + E) - f(A) + o(\\lVert E \\rVert). ``` """ _matfun_frechet """ - _matfun_frechet!(f, ΔA, A, Y, intermediates) + _matfun_frechet!(f, E, A, Y, intermediates) -Similar to [`_matfun_frechet`](@ref), but where `ΔA` may be overwritten. +Similar to [`_matfun_frechet`](@ref), but where `E` may be overwritten. """ _matfun_frechet! From d7d20ba66e7a0ce49e0131224bf7a01d7aa5e10e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 17:17:25 -0800 Subject: [PATCH 23/33] Add _matfun_frechet_adjoint --- src/rulesets/LinearAlgebra/matfun.jl | 43 +++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 68e60c477..071d9d0fb 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -44,6 +44,43 @@ Similar to [`_matfun_frechet`](@ref), but where `E` may be overwritten. """ _matfun_frechet! +""" + _matfun_frechet_adjoint(f, E, A, Y, intermediates) + +Compute the adjoint of the Fréchet derivative of the matrix function ``Y = f(A)`` at ``A`` +in the direction of ``E``, where `intermediates` is the second argument returned by +[`_matfun`](@ref). + +Given the Fréchet ``L_f(A, E)`` computed by [`_matfun_frechet`](@ref), then its adjoint +``L_f^⋆(A, E)`` is defined by the identity +```math +\\langle B, L_f(A, C) \\rangle = \\langle L_f^⋆(A, B), C \\rangle. +``` +This identity is satisfied by ``L_f^⋆(A, E) = L_f(A, E')'``. +""" +function _matfun_frechet_adjoint(f, E, A, Y, intermediates) + E′ = E' + # avoid passing an Adjoint to _matfun_frechet in case it can't handle it + E′ = E′ isa Adjoint ? copy(E′) : E′ + LE = adjoint(_matfun_frechet(f, E′, A, Y, intermediates)) + # avoid returning an Adjoint + return LE isa Adjoint ? copy(LE) : LE +end + +""" + _matfun_frechet_adjoint!(f, E, A, Y, intermediates) + +Similar to [`_matfun_frechet_adjoint`](@ref), but where `E` may be overwritten. +""" +function _matfun_frechet_adjoint!(f, E, A, Y, intermediates) + E′ = E' + # avoid passing an Adjoint to _matfun_frechet in case it can't handle it + E′ = E′ isa Adjoint ? copy(E′) : E′ + LE = adjoint(_matfun_frechet!(f, E′, A, Y, intermediates)) + # avoid returning an Adjoint + return LE isa Adjoint ? copy(LE) : LE +end + ##### ##### `exp`/`exp!` ##### @@ -71,18 +108,16 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat}) hermA = Hermitian(A0) hermX, hermX_intermediates = _matfun(exp, hermA) function exp_pullback_hermitian(ΔX) - ∂hermA = _matfun_frechet(exp, hermA, hermX, ΔX, hermX_intermediates) ∂hermA isa LinearAlgebra.RealHermSymComplexHerm || return NO_FIELDS, ∂hermA return NO_FIELDS, parent(∂hermA) + ∂hermA = _matfun_frechet_adjoint(exp, ΔX, hermA, hermX, hermX_intermediates) end return LinearAlgebra.copytri!(parent(hermX), 'U', true), exp_pullback_hermitian else A = copy(A0) X, intermediates = _matfun!(exp, A) function exp_pullback(ΔX) - ΔX′ = copy(adjoint(ΔX)) - ∂A′ = _matfun_frechet!(exp, A, X, ΔX′, intermediates) - ∂A = copy(adjoint(∂A′)) + ∂A = _matfun_frechet_adjoint!(exp, ΔX, A, X, intermediates) return NO_FIELDS, ∂A end return X, exp_pullback From b0ae61ce42a538e2c0dd94ccd373a8e35bc288a0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 17:17:42 -0800 Subject: [PATCH 24/33] Simplify hermitian code --- src/rulesets/LinearAlgebra/matfun.jl | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 071d9d0fb..849e7540e 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -87,13 +87,11 @@ end function frule((_, ΔA), ::typeof(LinearAlgebra.exp!), A::StridedMatrix{<:BlasFloat}) if ishermitian(A) - hermX, ∂hermX = frule((Zero(), ΔA), exp, Hermitian(A)) - X = LinearAlgebra.copytri!(parent(hermX), 'U', true) - if ∂hermX isa LinearAlgebra.RealHermSymComplexHerm - ∂X = LinearAlgebra.copytri!(parent(∂hermX), 'U', true) - else - ∂X = ∂hermX - end + hermA = Hermitian(A) + hermX, intermediates = _matfun(exp, hermA) + ∂hermX = _matfun_frechet(exp, ΔA, hermA, hermX, intermediates) + X = Matrix(hermX) + ∂X = Matrix(∂hermX) else X, intermediates = _matfun!(exp, A) ∂X = _matfun_frechet!(exp, ΔA, A, X, intermediates) @@ -108,11 +106,10 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat}) hermA = Hermitian(A0) hermX, hermX_intermediates = _matfun(exp, hermA) function exp_pullback_hermitian(ΔX) - ∂hermA isa LinearAlgebra.RealHermSymComplexHerm || return NO_FIELDS, ∂hermA - return NO_FIELDS, parent(∂hermA) ∂hermA = _matfun_frechet_adjoint(exp, ΔX, hermA, hermX, hermX_intermediates) + return NO_FIELDS, Matrix(∂hermA) end - return LinearAlgebra.copytri!(parent(hermX), 'U', true), exp_pullback_hermitian + return Matrix(hermX), exp_pullback_hermitian else A = copy(A0) X, intermediates = _matfun!(exp, A) From 62b963b6615e9d7fa4bc4dda267f5123c87e5c38 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 17:22:39 -0800 Subject: [PATCH 25/33] Correct comment --- src/rulesets/LinearAlgebra/matfun.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 849e7540e..cce77df77 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -1,12 +1,17 @@ # matrix functions of dense matrices # https://en.wikipedia.org/wiki/Matrix_function -# NOTE: for matrix functions whose power series representation has real coefficients, -# the pullback and pushforward are related by an adjoint. -# Specifically, if the pushforward of f(A) is (f_*)_A(ΔA), then the pullback at Y=f(A) is -# (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))' +# NOTE: for a matrix function f, the pushforward and pullback can be computed using the +# Fréchet derivative and its adjoint, respectively. +# The pushforwards and pullbacks are related by matrix adjoints. +# Specifically, if the pushforward of f(A) at A is (f_*)_A(ΔA), then the pullback at A is +# (f^*)_A(ΔY) = ((f_*)_A(ΔY'))'. +# If f has a power series representation with real coefficients, then this simplifies to +# (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) # So we reuse the code from the pushforward to implement the pullback. +# interface function definitions + """ _matfun(f, A) -> (Y, intermediates) From 87e4c5315cfa6fb524a327dd45177017c4e19235 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 17:23:32 -0800 Subject: [PATCH 26/33] Remove comments --- src/rulesets/LinearAlgebra/matfun.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index cce77df77..1e7ec4b6a 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -275,7 +275,6 @@ function _balance!(X, ilo, ihi, scale, n) end end - # Undo the balancing for j in ilo:ihi scj = scale[j] for i in 1:n @@ -300,12 +299,12 @@ function _unbalance!(X, ilo, ihi, scale, n) end end - if ilo > 1 # apply lower permutations in reverse order + if ilo > 1 for j in (ilo - 1):-1:1 LinearAlgebra.rcswap!(j, Int(scale[j]), X) end end - if ihi < n # apply upper permutations in forward order + if ihi < n for j in (ihi + 1):n LinearAlgebra.rcswap!(j, Int(scale[j]), X) end From 9bd06b152a200db5a72cc4cdd6659645738ce4ef Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 17:26:19 -0800 Subject: [PATCH 27/33] Use abbreviated SHA --- src/rulesets/LinearAlgebra/matfun.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 1e7ec4b6a..58bedb3af 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -129,7 +129,7 @@ end ## Destructive matrix exponential using algorithm from Higham, 2008, ## "Functions of Matrices: Theory and Computation", SIAM ## Adapted from LinearAlgebra.exp! with return of intermediates -## https://github.com/JuliaLang/julia/blob/f613b551009a7a9dbe46235929099f2ecd28bed1/stdlib/LinearAlgebra/src/dense.jl#L583-L666 +## https://github.com/JuliaLang/julia/blob/f613b55/stdlib/LinearAlgebra/src/dense.jl#L583-L666 function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat} n = LinearAlgebra.checksquare(A) ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A From 2ed06e63b67268440470917c5823e8269391b1b4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 21:06:55 -0800 Subject: [PATCH 28/33] Link --- src/rulesets/LinearAlgebra/matfun.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 58bedb3af..75535d394 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -17,7 +17,7 @@ Compute the matrix function `Y=f(A)` for matrix `A`. The function returns a tuple containing the result and a tuple of intermediates to be -reused by `_matfun_frechet` to compute the Fréchet derivative. +reused by [`_matfun_frechet`](@ref) to compute the Fréchet derivative. Note that any function `f` used with this **must** have a `frule` defined on it. """ _matfun From 156e6f5ec55b8402f5d735e80eaa2751783dc057 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 21:09:50 -0800 Subject: [PATCH 29/33] Update comment --- src/rulesets/LinearAlgebra/matfun.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 75535d394..cefa6cc24 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -3,9 +3,9 @@ # NOTE: for a matrix function f, the pushforward and pullback can be computed using the # Fréchet derivative and its adjoint, respectively. -# The pushforwards and pullbacks are related by matrix adjoints. -# Specifically, if the pushforward of f(A) at A is (f_*)_A(ΔA), then the pullback at A is -# (f^*)_A(ΔY) = ((f_*)_A(ΔY'))'. +# https://en.wikipedia.org/wiki/Fréchet_derivative +# The pushforwards and pullbacks are related by matrix adjoints. If the pushforward of f(A) +# at A is (f_*)_A(ΔA), then the pullback at A is (f^*)_A(ΔY) = ((f_*)_A(ΔY'))'. # If f has a power series representation with real coefficients, then this simplifies to # (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) # So we reuse the code from the pushforward to implement the pullback. From 9a63d13704213dff23d1857006584b0c0ee5d93d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 21:18:56 -0800 Subject: [PATCH 30/33] Move comment up --- test/rulesets/LinearAlgebra/matfun.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rulesets/LinearAlgebra/matfun.jl b/test/rulesets/LinearAlgebra/matfun.jl index 2e98d5fd4..23eb4655e 100644 --- a/test/rulesets/LinearAlgebra/matfun.jl +++ b/test/rulesets/LinearAlgebra/matfun.jl @@ -2,10 +2,10 @@ @testset "LinearAlgebra.exp!(A::Matrix) frule" begin n = 10 @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), + # choose normalization to hit specific branch nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) - # choose normalization to hit specific branch A *= nrm / opnorm(A, 1) frule_test(LinearAlgebra.exp!, (A, ΔA)) end @@ -25,11 +25,11 @@ @testset "exp(A::Matrix) rrule" begin n = 10 @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), + # choose normalization to hit specific branch nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) ΔY = randn(ComplexF64, n, n) - # choose normalization to hit specific branch A *= nrm / opnorm(A, 1) # rrule is not inferrable, but pullback should be rrule_test(exp, ΔY, (A, ΔA); check_inferred=false) From 5ba193d105083aee08300e06973e00fc01bf5721 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 21:21:23 -0800 Subject: [PATCH 31/33] Move comment further up --- test/rulesets/LinearAlgebra/matfun.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rulesets/LinearAlgebra/matfun.jl b/test/rulesets/LinearAlgebra/matfun.jl index 23eb4655e..065574611 100644 --- a/test/rulesets/LinearAlgebra/matfun.jl +++ b/test/rulesets/LinearAlgebra/matfun.jl @@ -1,8 +1,8 @@ @testset "matrix functions" begin @testset "LinearAlgebra.exp!(A::Matrix) frule" begin n = 10 + # each normalization hits a specific branch @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), - # choose normalization to hit specific branch nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) @@ -24,8 +24,8 @@ @testset "exp(A::Matrix) rrule" begin n = 10 + # each normalization hits a specific branch @testset "A::Matrix{$T}, opnorm(A,1)=$nrm" for T in (Float64, ComplexF64), - # choose normalization to hit specific branch nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0) A, ΔA = randn(ComplexF64, n, n), randn(ComplexF64, n, n) From 49df929316decf817c7aef459272b414f753350e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 21:33:40 -0800 Subject: [PATCH 32/33] Update docstrings --- src/rulesets/LinearAlgebra/matfun.jl | 10 +++++++++- src/rulesets/LinearAlgebra/symmetric.jl | 3 ++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index cefa6cc24..87d5d371f 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -4,6 +4,7 @@ # NOTE: for a matrix function f, the pushforward and pullback can be computed using the # Fréchet derivative and its adjoint, respectively. # https://en.wikipedia.org/wiki/Fréchet_derivative + # The pushforwards and pullbacks are related by matrix adjoints. If the pushforward of f(A) # at A is (f_*)_A(ΔA), then the pullback at A is (f^*)_A(ΔY) = ((f_*)_A(ΔY'))'. # If f has a power series representation with real coefficients, then this simplifies to @@ -18,7 +19,6 @@ Compute the matrix function `Y=f(A)` for matrix `A`. The function returns a tuple containing the result and a tuple of intermediates to be reused by [`_matfun_frechet`](@ref) to compute the Fréchet derivative. -Note that any function `f` used with this **must** have a `frule` defined on it. """ _matfun @@ -39,6 +39,10 @@ The Fréchet derivative is the unique linear map ``L_f \\colon E → L_f(A, E)`` ```math L_f(A, E) = f(A + E) - f(A) + o(\\lVert E \\rVert). ``` + +[^Higham08]: + > Higham, Nicholas J. Chapter 3: Conditioning. Functions of Matrices. 2008, 55-70. + > doi: 10.1137/1.9780898717778.ch3 """ _matfun_frechet @@ -62,6 +66,10 @@ Given the Fréchet ``L_f(A, E)`` computed by [`_matfun_frechet`](@ref), then its \\langle B, L_f(A, C) \\rangle = \\langle L_f^⋆(A, B), C \\rangle. ``` This identity is satisfied by ``L_f^⋆(A, E) = L_f(A, E')'``. + +[^Higham08]: + > Higham, Nicholas J. Chapter 3: Conditioning. Functions of Matrices. 2008, 55-70. + > doi: 10.1137/1.9780898717778.ch3 """ function _matfun_frechet_adjoint(f, E, A, Y, intermediates) E′ = E' diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 7def77cac..9471b3ba5 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -366,7 +366,8 @@ end Compute the matrix function `f(A)` for real or complex hermitian `A`. The function returns a tuple containing the result and a tuple of intermediates to be -reused by `_matfun_frechet` to compute the Fréchet derivative. +reused by [`_matfun_frechet`](@ref) to compute the Fréchet derivative. + Note any function `f` used with this **must** have a `frule` defined on it. """ function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm) From 8c27276f5199d64f6af8fc1cdeeb4e62d0843839 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 19 Jan 2021 21:34:44 -0800 Subject: [PATCH 33/33] Push header to same level as rules --- src/rulesets/LinearAlgebra/matfun.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 87d5d371f..2b909ac89 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -11,7 +11,9 @@ # (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) # So we reuse the code from the pushforward to implement the pullback. -# interface function definitions +##### +##### interface function definitions +##### """ _matfun(f, A) -> (Y, intermediates)