-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add @not_implemented
#335
Add @not_implemented
#335
Conversation
Codecov Report
@@ Coverage Diff @@
## master #335 +/- ##
==========================================
- Coverage 90.04% 89.50% -0.54%
==========================================
Files 13 14 +1
Lines 472 543 +71
==========================================
+ Hits 425 486 +61
- Misses 47 57 +10
Continue to review full report at Codecov.
|
I just updated the definition of the differential of julia> Zygote.gradient(1.0) do x
besseli(1, x)
end It still throws an error since |
With the reduced implementations, the MWE does not error anymore. When differentiating with respect to the order the julia> Zygote.gradient(1.0) do x
besseli(x, 3.0)
end
(ChainRulesCore.NotImplemented{Nothing, Nothing, Nothing}(nothing, nothing, nothing),) |
It seems a bit unfortunate that one has to define linear operations with |
Nice. I guess it could also be useful as a short term hack for e.g. Zygote ADing things that it can't handle but that commonly occur not on the actual path of concern (e.g. some function being called and logged). |
Re: errors. So we either need to overload linear operators to display that. I think overloading to propagate is fine. (Another reason why it is like AbsractZero. Maybe we should give a common super type of AbstractZero and this?) Also it's worth checking if hiding having fields behinds debug mode is worth it. |
I wonder if one could avoid the definitions of |
Thanks, that looks great! Just playing around locally a bit, will comment in a couple of hours |
Overall this looks great, thanks for implementing it! I share your concern about overloading linear operators for Would it be advantageous to only keep the Just so we don't forget, a few of things to do before we merge:
|
src/differentials/notimplemented.jl
Outdated
|
||
# Linear operators | ||
Base.adjoint(x::NotImplemented) = x | ||
Base.transpose(x::NotImplemented) = x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably don't need these two since conj
falls back to the identical AbstractDifferential
definition.
src/differentials/notimplemented.jl
Outdated
|
||
Base.:+(x::NotImplemented, ::Any) = x | ||
Base.:+(::Any, x::NotImplemented) = x | ||
Base.:+(x::NotImplemented, ::NotImplemented) = x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason we need these? Perhaps we could make it error?
Ah, one more thing: julia> @btime Zygote.gradient(x -> besseli(x, 1.0), 1.0)[1]
351.905 ns (3 allocations: 48 bytes)
NotImplemented(SpecialFunctions, #= /Users/mzgubic/JuliaEnvs/test/notimplemented/dev/SpecialFunctions/src/chainrules.jl:42 =#, besseli first derivative is not implemented)
julia> @btime Zygote.gradient(x -> besseli(x, 1.0), 1.0)[1]
373.468 ns (3 allocations: 48 bytes)
NotImplemented() It doesn't seem that the extra info is slowing it down, I would be in favour of using the extra information without debug mode in order to make it more user friendly |
AFAICT the current implementation of |
Nice, thanks for looking into this! Would you like me to perform some more benchmarks? It's great if we can provide better error messages without affecting performance 👍 |
The other extreme I guess is where one looks at julia> @btime @not_implemented("myerror")
0.047 ns (0 allocations: 0 bytes)
NotImplemented()
julia> @btime @not_implemented("myerror")
2.712 ns (0 allocations: 0 bytes)
NotImplemented(Main, #= REPL[6]:1 =#, myerror) But since it's such a small absolute number and the linear operators should constant fold it is probably fine? If you can think of more meaningful benchmarks I would love to see them |
I have marked this as assigned to @mzgubic to take over the line for final review. |
Co-authored-by: Lyndon White <[email protected]>
julia> g() = ChainRulesCore.@not_implemented("test")
g (generic function with 2 methods)
julia> ChainRulesCore.debug_mode() = false
julia> @btime g()
2.306 ns (0 allocations: 0 bytes)
NotImplemented(Main, #= REPL[43]:1 =#, test)
julia> ChainRulesCore.debug_mode() = true
julia> @btime g()
2.307 ns (0 allocations: 0 bytes)
NotImplemented(Main, #= REPL[43]:1 =#, test) Since I haven't found an example where the debugging information affects performance I remove the check for Edit: I missed that I had enabled debug mode in both benchmarks, I observe the same difference if performed correctly: julia> ChainRulesCore.debug_mode() = false
julia> g() = ChainRulesCore.@not_implemented("test")
g (generic function with 2 methods)
julia> @btime g()
0.018 ns (0 allocations: 0 bytes)
NotImplemented()
julia> ChainRulesCore.debug_mode() = true
julia> g() = ChainRulesCore.@not_implemented("test")
g (generic function with 1 method
julia> @btime g()
2.306 ns (0 allocations: 0 bytes)
NotImplemented(Main, #= REPL[10]:1 =#, test) |
docs/src/writing_good_rules.md
Outdated
|
||
One can use [`@not_implemented`](@ref) to mark missing differentials. This is helpful if | ||
the function has multiple inputs or outputs, and one has worked out analytically and | ||
implemented some but not all differentials. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a nitpick, but in markdown we usually have one sentence per line (rather than the character limit)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah OK, I'll fix it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Thank you for this wonderful contribution! I've added a few minor comments, but I think we are good to go. Let me know when you would like it merged. |
@devmotion has merge rights on the ChainRules project so can merge when ready |
Co-authored-by: Miha Zgubic <[email protected]>
Co-authored-by: Miha Zgubic <[email protected]>
Oh, ok, I assumed not since the PR is from a fork, please go ahead @devmotion How do i check who has merge rights? It came up a couple of times now that I wish I knew |
This is a first draft that addresses #334. It adds an unexported
NotImplemented
differential that is constructed with an exported@not_implemented
helper. IfChainRulesCore.debug_mode() == true
, it tracks debugging information such as module and location where the differential is missing and, possibly, additional user-provided information. It throws an error if one tries to perform computations with the differential (I just implemented some of the methods that are defined for other differentials but maybe the list should be shortened or extended).Examples:
I haven't checked yet if/how breaking it is for Zygote.