Add rules for dense matrix exponential #351

Merged: 34 commits, Jan 20, 2021

Conversation

sethaxen
Member

Fixes #331

Because I plan to add a number of matrix functions, and they are long, I have added these rules to a new file, matfun.jl.
These rules unfortunately require quite a bit of code duplication from LinearAlgebra, but following the discussion in JuliaLang/julia#5840, that's only avoidable if we refactor these functions in LinearAlgebra to return their intermediates.

# NOTE: for matrix functions whose power series representation has real coefficients,
# the pullback and pushforward are related by an adjoint.
# Specifically, if the pushforward of f(A) is (f_*)_A(ΔA), then the pullback at Y=f(A) is
# (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))'
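
The relation in the note above can be checked numerically for exp. The following is a minimal sketch, not the code from this PR: `exp_pushforward` and `inner` are hypothetical helpers introduced only for illustration, and the pushforward is computed with the standard block-matrix identity exp([A ΔA; 0 A]) = [exp(A) L; 0 exp(A)], where L is the pushforward of exp at A applied to ΔA. It assumes the pullback is the adjoint of the pushforward with respect to the real inner product Re tr(X' * Y).

```julia
using LinearAlgebra

# Hypothetical helper (not the ChainRules implementation): pushforward of exp
# at A in direction ΔA, read off the upper-right block of exp([A ΔA; 0 A]).
function exp_pushforward(A::AbstractMatrix, ΔA::AbstractMatrix)
    n = size(A, 1)
    Z = zeros(eltype(A), n, n)
    M = [A ΔA; Z A]
    return exp(M)[1:n, (n + 1):(2n)]
end

# Real inner product on complex matrices: ⟨X, Y⟩ = Re tr(X' * Y).
inner(X, Y) = real(dot(X, Y))

# Small fixed complex test matrices (arbitrary values chosen for illustration).
A = ComplexF64[0.1+0.2im 0.4; -0.3 0.5-0.1im]
ΔA = ComplexF64[0.3 0.2im; 0.1+0.1im 0.7]
ΔY = ComplexF64[0.5im 0.2; 0.6 0.1+0.3im]

# Candidate pullback: the pushforward evaluated at A' and applied to ΔY.
pullback = exp_pushforward(Matrix(A'), ΔY)

# Defining property of the pullback: ⟨ΔY, pushforward(ΔA)⟩ == ⟨pullback(ΔY), ΔA⟩.
@assert inner(ΔY, exp_pushforward(A, ΔA)) ≈ inner(pullback, ΔA)

# Second form from the note: the adjoint of the pushforward of the adjoint.
@assert pullback ≈ exp_pushforward(A, Matrix(ΔY'))'
```

The block-matrix route is used here only because it needs nothing beyond `LinearAlgebra.exp`; it exponentiates a matrix of twice the size, so it is a convenient check rather than an efficient rule.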
Member Author

Well that is hideous, but notation is hard, and harder in unicode.

Member

Unicode's missing subscripts make it extra hard.

Idea that might not be worth doing:
What if we just made a section for this in the docs (maybe as internal notes or something),
wrote the LaTeX there, and then linked to that?

But yeah, notation for pullbacks and pushforwards is hard.
It has to convey so much state.

Member

This bit makes sense:
(f^*)_Y(ΔY) = ((f_*)_A(ΔY'))'
so the pullback at A, i.e. the pullback from Y (though that's not well defined, since not all functions are monotonic?),
is equal to the adjoint of the pushing forward at A, the adjoint of the output sensitivity.
The fact that that is also equal to (f_*)_{A'}(ΔY) is pretty magic.

Magical exponential symmetry? (I feel like I made the same surprised sounds for the same reason on your last PR.)

Member Author

What if we just made a section for this in the docs (maybe as internal notes or something),
wrote the LaTeX there, and then linked to that?

Hm, that's an idea. I'll consider it, potentially for a future PR.

This bit makes sense:
(f^*)_Y(ΔY) = ((f_*)_A(ΔY'))'
so the pullback at A, i.e. the pullback from Y...is equal to the adjoint of the pushing forward at A, the adjoint of the output sensitivity.

Ah yes your description is correct (although it's the adjoint of the pushing forward of the adjoint). I just checked Lee, and this should be the right notation:

Suggested change
- # (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))'
+ # (f^*)_A(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))'

(though that's not well defined since not all functions are monotonic?)

I'm not sure what you mean by this.

Magical exponential symmetry? (I feel like I made the same surprised sounds for the same reason on your last PR.)

It's still surprising to me. Although this property is general for all of the matrix functions defined in LinearAlgebra, not just exp. It doesn't follow for all matrix functions though, just those whose convergent power series have real coefficients.
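
One way to see where the real-coefficient condition enters is the following sketch. It assumes f has a convergent power series f(A) = Σ_k c_k A^k with real c_k, writes A^H for the adjoint written A' in the code comments, and takes the pullback to be the adjoint of the pushforward with respect to the real inner product ⟨X, Y⟩ = Re tr(X^H Y).

```latex
% Assumptions: f(A) = \sum_k c_k A^k with real c_k, and the pullback is the
% adjoint of the pushforward under <X, Y> = Re tr(X^H Y).
\begin{align*}
(f_*)_A(\Delta A)
  &= \sum_{k \ge 1} c_k \sum_{j=0}^{k-1} A^{j}\, \Delta A\, A^{k-1-j}, \\
\langle \Delta Y, (f_*)_A(\Delta A) \rangle
  &= \sum_{k \ge 1} c_k \sum_{j=0}^{k-1}
     \operatorname{Re}\operatorname{tr}\!\left( \Delta Y^{H} A^{j}\, \Delta A\, A^{k-1-j} \right) \\
  &= \sum_{k \ge 1} c_k \sum_{j=0}^{k-1}
     \operatorname{Re}\operatorname{tr}\!\left( \left( (A^{H})^{j}\, \Delta Y\, (A^{H})^{k-1-j} \right)^{H} \Delta A \right) \\
  &= \langle (f_*)_{A^{H}}(\Delta Y), \Delta A \rangle .
\end{align*}
% Hence (f^*)_A(\Delta Y) = (f_*)_{A^H}(\Delta Y), and taking adjoints term by
% term gives (f_*)_{A^H}(\Delta Y) = ( (f_*)_A(\Delta Y^H) )^H.
```

The second-to-third step uses only the cyclic property of the trace. The last step moves each c_k into the first slot of the inner product, which leaves it unchanged only when c_k is real; with complex coefficients the adjoint would instead involve the series with conjugated coefficients, which is a different function.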

Member

oxinabox left a comment

I started to comment on the formatting, but it's a bit much.
I think just run https://github.com/domluna/JuliaFormatter.jl/ over it:
format_file("matfun.jl", BlueStyle())
It's pretty good.

I will review after that, since then I won't be spending time on the basic stuff.

Member Author

sethaxen commented Jan 18, 2021

I started to comment on the formatting, but it's a bit much.

Yeah, it's the formatting used by exp! in LinearAlgebra, which is not great.

I think just run https://github.com/domluna/JuliaFormatter.jl/ over it
format_file("matfun.jl", BlueStyle())

Done!

codecov-io commented Jan 18, 2021

Codecov Report

Merging #351 (8c27276) into master (9004ee0) will decrease coverage by 11.21%.
The diff coverage is 100.00%.

@@             Coverage Diff             @@
##           master     #351       +/-   ##
===========================================
- Coverage   97.64%   86.43%   -11.22%     
===========================================
  Files          18       19        +1     
  Lines        1231     1172       -59     
===========================================
- Hits         1202     1013      -189     
- Misses         29      159      +130     
Impacted Files                             Coverage Δ
src/ChainRules.jl                          66.66% <ø> (-33.34%) ⬇️
src/rulesets/LinearAlgebra/matfun.jl       100.00% <100.00%> (ø)
src/rulesets/LinearAlgebra/symmetric.jl    83.15% <100.00%> (-15.55%) ⬇️
src/rulesets/Base/evalpoly.jl              0.00% <0.00%> (-97.68%) ⬇️
src/rulesets/Base/utils.jl                 0.00% <0.00%> (-80.00%) ⬇️
src/rulesets/Statistics/statistics.jl      66.66% <0.00%> (-23.34%) ⬇️
src/rulesets/LinearAlgebra/utils.jl        66.66% <0.00%> (-20.00%) ⬇️
src/rulesets/LinearAlgebra/structured.jl   92.04% <0.00%> (-6.84%) ⬇️
... and 10 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 9004ee0...8c27276.

Member Author

@oxinabox I added a _matfun_frechet_adjoint, changed the signature to put the differential first, and substantially modified the docstrings and comments. Would you mind re-reviewing the comments and docstrings at the top of matfun.jl before I merge?

Member

oxinabox left a comment

LGTM

sethaxen merged commit b92da50 into JuliaDiff:master on Jan 20, 2021
sethaxen deleted the exp2 branch on January 20, 2021 18:34
bors bot added a commit to FluxML/Zygote.jl that referenced this pull request Jan 21, 2021
888: Remove rules for matrix exponential r=DhairyaLGandhi a=sethaxen

JuliaDiff/ChainRules.jl#351 added rules for the dense matrix exponential to ChainRules. This PR removes the corresponding adjoint from Zygote.

Co-authored-by: Seth Axen <[email protected]>
Development

Successfully merging this pull request may close these issues.

Add rules for the matrix exponential
4 participants