chain rules for DCT #273
Codecov Report
Coverage diff, master vs. #273:
| | master | #273 | +/- |
|---|---|---|---|
| Coverage | 73.08% | 70.70% | -2.38% |
| Files | 5 | 6 | +1 |
| Lines | 535 | 553 | +18 |
| Hits | 391 | 391 | |
| Misses | 144 | 162 | +18 |
This error only happens with the MKL provider. With MKL, FFTW.jl doesn't even compile on my machine. Could it be due to 008bc5b?
test_frule: idct on Array{Float64, 3},Int64: Error During Test at /home/runner/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:123
  Got exception outside of a @test
  FFTW could not create plan
  Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:35
    [2] macro expansion
      @ FFTW ~/work/FFTW.jl/FFTW.jl/src/fft.jl:722 [inlined]
@devmotion, could you please review this? LMK if you want me to remove the |
Can you explain the idea of the PR? The design goal was to define the differentiation rules in AbstractFFTs via adjoint plans. Downstream packages were supposed to implement this new adjoint interface.
Only the MKL tests are failing. Is there a regression with MKL?
More concretely, #249 outlines the intended approach. FFTW was not supposed to define custom rules.
This PR defines Chain Rules for |
With #249, gradient computation would error for DCT/IDCT:
julia> using LinearAlgebra, FFTW, Zygote
julia> x = rand(4)
4-element Vector{Float64}:
0.8692266334693106
0.6938635624794242
0.552208368655668
0.9197557963740512
julia> f(x) = x |> dct |> idct |> norm
f (generic function with 1 method)
julia> f(x)
1.5452787421840921
julia> Zygote.gradient(f, x)
ERROR: Compiling Tuple{Type{FFTW.r2rFFTWPlan{Float64, Any, false, 1}}, Vector{Float64}, FFTW.FakeArray{Float64, 1}, UnitRange{Int64}, Int64, UInt32, Float64}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
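For context on why explicit rules for `dct` and `idct` can be short, here is a minimal sanity check. It is a sketch, not the PR's code, and it assumes FFTW.jl's documented unitary normalization of `dct` (with `idct` as its inverse). The names `x`, `dx`, and `ybar` are illustrative only.

```julia
using FFTW, LinearAlgebra

x    = rand(8)   # primal input
dx   = rand(8)   # tangent, as used in forward mode
ybar = rand(8)   # cotangent, as used in reverse mode

# Forward mode: dct is linear, so the pushforward of dx is just dct(dx).
@assert dct(x + dx) - dct(x) ≈ dct(dx)

# Reverse mode: with unitary normalization the transform is orthogonal,
# so its adjoint equals its inverse: <dct(dx), ybar> == <dx, idct(ybar)>.
@assert dot(dct(dx), ybar) ≈ dot(dx, idct(ybar))

# Round trip recovers the input, confirming idct inverts dct.
@assert idct(dct(x)) ≈ x
```

Under that assumption, a pullback for `dct` can apply `idct` to the incoming cotangent (and vice versa), which avoids differentiating through plan construction, where Zygote hits the `try/catch` limitation shown above.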
Sorry, somehow I missed yesterday that the PR does not add rules for plans but for the dct, idct and r2r functions 🤦
As long as they/their interface is not moved to AbstractFFTs, rules should be defined here 👍
ext/FFTWChainRulesCoreExt.jl (outdated)
# R2R
function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, region...)
It seems the rrule for r2r is missing?
The R2R transforms are not unitary. There is some scaling involved that depends on the kind of R2R transform. Because it looks like an involved task, I chose to skip that for now. I am happy to look into that in a separate PR
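To illustrate the scaling issue with one pair of kinds, the following is a worked sketch based on the unnormalized 1-D definitions of REDFT10 (DCT-II) and REDFT01 (DCT-III) as given in the FFTW manual. The cotangent notation $\bar{X}$, $\bar{Y}$ is introduced here for illustration. Other kinds would need their own derivation.

```latex
% FFTW's unnormalized definitions for a length-n input X:
\begin{align*}
\text{REDFT10:}\quad Y_k &= 2\sum_{j=0}^{n-1} X_j
    \cos\!\left[\frac{\pi\,(j+\tfrac12)\,k}{n}\right],\\
\text{REDFT01:}\quad Y_k &= X_0 + 2\sum_{j=1}^{n-1} X_j
    \cos\!\left[\frac{\pi\,j\,(k+\tfrac12)}{n}\right].
\end{align*}
% Adjoint (transpose) of REDFT10 applied to a cotangent \bar{Y}:
\begin{align*}
\bar{X}_j &= 2\sum_{k=0}^{n-1} \bar{Y}_k
    \cos\!\left[\frac{\pi\,(j+\tfrac12)\,k}{n}\right]
  = \bar{Y}_0 + \mathrm{REDFT01}(\bar{Y})_j .
\end{align*}
```

So the adjoint of REDFT10 is REDFT01 plus a correction from the first coefficient, not REDFT01 itself, and the round trip REDFT10 followed by REDFT01 scales the input by $2n$. The forward rule is unaffected, since each R2R transform is linear and its pushforward is the same transform applied to the tangent.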
Co-authored-by: David Widmann <[email protected]>
@devmotion I've addressed all your comments. LMK if you have more questions :D
@devmotion ping :)
address #272