-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rules for dense matrix exponential #351
Conversation
src/rulesets/LinearAlgebra/matfun.jl
Outdated
# NOTE: for matrix functions whose power series representation has real coefficients, | ||
# the pullback and pushforward are related by an adjoint. | ||
# Specifically, if the pushforward of f(A) is (f_*)_A(ΔA), then the pullback at Y=f(A) is | ||
# (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well that is hideous, but notation is hard, and harder in unicode.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unicode's missing subscripts make it extra hard.
Idea that might not be worth doing:
What if we just made a section for this in the docs, (maybe as internal notes or something)
and wrote the latex and then linked to that?
But yeah notation for pullbacks and pushforwards is hard.
It has to convey so much state
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This bit makes sense:
(f^*)_Y(ΔY) = ((f_*)_A(ΔY'))'
so the pullback at A
, i.e. the pullback from Y (though that's not well defined since not all functions are monotonic?)
is equal to the the adjoint of the pushing forward at A
, the adjoint of of the output senstivity.
the fact that that is also equal to (f_*)_{A'}(ΔY)
is pretty magic.
Magical expodential symmetry? (I feel like i made the same suprised sounds for the same reason on your last PR)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if we just made a section for this in the docs, (maybe as internal notes or something)
and wrote the latex and then linked to that?
Hm, that's an idea. I'll consider it, potentially for a future PR.
This bit makes sense:
(f^*)_Y(ΔY) = ((f_*)_A(ΔY'))'
so the pullback atA
, i.e. the pullback from Y...is equal to the the adjoint of the pushing forward atA
, the adjoint of of the output senstivity.
Ah yes your description is correct (although it's the adjoint of the pushing forward of the adjoint). I just checked Lee, and this should be the right notation:
# (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))' | |
# (f^*)_A(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))' |
(though that's not well defined since not all functions are monotonic?)
I'm not sure what you mean by this.
Magical expodential symmetry? (I feel like i made the same suprised sounds for the same reason on your last PR)
It's still surprising to me. Although this property is general for all of the matrix functions defined in LinearAlgebra, not just exp
. It doesn't follow for all matrix functions though, just those whose convergent power series have real coefficients.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I started to comment on the formatting, but its a bit much.
I think just run https://github.com/domluna/JuliaFormatter.jl/ over it
format_file("matfun.jl", BlueStyle())
its pretty good.
I will review after that, since then I won't be spending time on the basic stuff.
Yeah, it's the formatting used by
Done! |
Codecov Report
@@ Coverage Diff @@
## master #351 +/- ##
===========================================
- Coverage 97.64% 86.43% -11.22%
===========================================
Files 18 19 +1
Lines 1231 1172 -59
===========================================
- Hits 1202 1013 -189
- Misses 29 159 +130
Continue to review full report at Codecov.
|
src/rulesets/LinearAlgebra/matfun.jl
Outdated
# NOTE: for matrix functions whose power series representation has real coefficients, | ||
# the pullback and pushforward are related by an adjoint. | ||
# Specifically, if the pushforward of f(A) is (f_*)_A(ΔA), then the pullback at Y=f(A) is | ||
# (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unicode's missing subscripts make it extra hard.
Idea that might not be worth doing:
What if we just made a section for this in the docs, (maybe as internal notes or something)
and wrote the latex and then linked to that?
But yeah notation for pullbacks and pushforwards is hard.
It has to convey so much state
src/rulesets/LinearAlgebra/matfun.jl
Outdated
# NOTE: for matrix functions whose power series representation has real coefficients, | ||
# the pullback and pushforward are related by an adjoint. | ||
# Specifically, if the pushforward of f(A) is (f_*)_A(ΔA), then the pullback at Y=f(A) is | ||
# (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY) = ((f_*)_A(ΔY'))' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This bit makes sense:
(f^*)_Y(ΔY) = ((f_*)_A(ΔY'))'
so the pullback at A
, i.e. the pullback from Y (though that's not well defined since not all functions are monotonic?)
is equal to the the adjoint of the pushing forward at A
, the adjoint of of the output senstivity.
the fact that that is also equal to (f_*)_{A'}(ΔY)
is pretty magic.
Magical expodential symmetry? (I feel like i made the same suprised sounds for the same reason on your last PR)
@oxinabox I added a |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
888: Remove rules for matrix exponential r=DhairyaLGandhi a=sethaxen JuliaDiff/ChainRules.jl#351 added rules for the dense matrix exponential to ChainRules. This PR removes the corresponding adjoint from Zygote. Co-authored-by: Seth Axen <[email protected]>
Fixes #331
Because I plan to add a number of matrix functions, and they are long, I have added these rules to a new function
matfun.jl
.These rules unfortunately require quite a bit of code duplication from LinearAlgebra, but following discussion in JuliaLang/julia#5840, that's only unavoidable if we refactor these functions in LinearAlgebra to return their intermediates.