Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @not_implemented #335

Merged
merged 29 commits into from
Apr 20, 2021
Merged

Conversation

devmotion
Copy link
Member

This is a first draft that addresses #334. It adds an unexported NotImplemented differential that is constructed with an exported @not_implemented helper. If ChainRulesCore.debug_mode() == true, it tracks debugging information such as module and location where the differential is missing and, possibly, additional user-provided information. It throws an error if one tries to perform computations with the differential (I just implemented some of the methods that are defined for other differentials but maybe the list should be shortened or extended).

Examples:

julia> ChainRulesCore.debug_mode() = false # default

julia> zero(ChainRulesCore.@not_implemented())
ERROR: differential not implemented
Stacktrace:
...

julia> zero(ChainRulesCore.@not_implemented("more info"))
ERROR: differential not implemented
Stacktrace:
...

julia> ChainRulesCore.debug_mode() = true

julia> zero(ChainRulesCore.@not_implemented())
ERROR: differential not implemented @ Main #= REPL[18]:1 =#
Stacktrace:
...

julia> zero(ChainRulesCore.@not_implemented("more info"))
ERROR: differential not implemented @ Main #= REPL[20]:1 =#
Info: more info
Stacktrace:
...

I haven't checked yet if/how breaking it is for Zygote.

@devmotion devmotion marked this pull request as draft April 16, 2021 19:49
@codecov-commenter
Copy link

codecov-commenter commented Apr 16, 2021

Codecov Report

Merging #335 (17935fe) into master (8315231) will decrease coverage by 0.53%.
The diff coverage is 85.71%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #335      +/-   ##
==========================================
- Coverage   90.04%   89.50%   -0.54%     
==========================================
  Files          13       14       +1     
  Lines         472      543      +71     
==========================================
+ Hits          425      486      +61     
- Misses         47       57      +10     
Impacted Files Coverage Δ
src/ChainRulesCore.jl 100.00% <ø> (ø)
src/differentials/notimplemented.jl 67.85% <67.85%> (ø)
src/differential_arithmetic.jl 96.72% <97.61%> (+0.47%) ⬆️
src/differentials/composite.jl 82.20% <0.00%> (+0.15%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 8315231...17935fe. Read the comment docs.

@devmotion
Copy link
Member Author

I just updated the definition of the differential of SpecialFunctions.besseli locally and checked the MWE in FluxML/Zygote.jl#873:

julia> Zygote.gradient(1.0) do x
           besseli(1, x)
       end

It still throws an error since conj is applied in the pullback defined by @scalar_rule (the full computation is conj(NotImplemented(...)) * var"..."). I assume this means that one should rather use the default conj(x::AbstractDifferential) = x instead of throwing an error.

@devmotion
Copy link
Member Author

With the reduced implementations, the MWE does not error anymore. When differentiating with respect to the order the NotImplemented is returned, e.g., one obtains

julia> Zygote.gradient(1.0) do x
           besseli(x, 3.0)
       end
(ChainRulesCore.NotImplemented{Nothing, Nothing, Nothing}(nothing, nothing, nothing),)

@devmotion
Copy link
Member Author

It seems a bit unfortunate that one has to define linear operations with NotImplemented since otherwise pullbacks throw an error if one wants that computations with it throw an error. The question is, is it important to throw an error for some computations or should just most operations propagate it?

@oxinabox
Copy link
Member

Nice.
Let's include strong guidance in the docstrings about not using this if the code could be AD'ed through. Only good if the code would error if tried to AD.
And point out it is most useful for if the function is had multiple returns and you have worked out analysically some of but not all of them.

I guess it could also be useful as a short term hack for e.g. Zygote ADing things that it can't handle but that commonly occur not on the actual path of concern (e.g. some function being called and logged).
Since it is more correct than marking those nondifferentiable.

@oxinabox
Copy link
Member

Re: errors.
The main problem would be that MethodErrors being thrown
And we want to use a custom NotImplementedException that displays the metadata we gave it.
Or that if that wasn't captured displays a .message saying to rerun with debug mode on to identify the source of this.

So we either need to overload linear operators to display that.
Or maybe we get that into into the MethodError by putting it as tyoeparams. Less pretty but gets it across.

I think overloading to propagate is fine. (Another reason why it is like AbsractZero. Maybe we should give a common super type of AbstractZero and this?)
That would mean it would stick around until you tried to add it to a primal e.g. in a optimizer (this is the behavior I wanted for DoesNotExist originally and still kinda do)
And that could be OK since it doesn't matter how much work you do with this, as long as you never add it to a primal, you never use it. You never get a wrong answer

Also it's worth checking if hiding having fields behinds debug mode is worth it.
Does it allocate if we store these, or does Julia do something with compile time literals.

@devmotion
Copy link
Member Author

I wonder if one could avoid the definitions of conj etc. by modifying @scalar_rule, e.g., by defining a single function as suggested in #309 for the propagation expression and specializing it for NotImplemented.

@mzgubic
Copy link
Member

mzgubic commented Apr 19, 2021

Thanks, that looks great! Just playing around locally a bit, will comment in a couple of hours

@mzgubic
Copy link
Member

mzgubic commented Apr 19, 2021

Overall this looks great, thanks for implementing it!

I share your concern about overloading linear operators for NotImplemented. On one hand it is pretty harmless since as @oxinabox points out we never get a wrong result, but it does feel wrong. I will take a look into #309 to see if we can use that instead (my preferred option), but I don't think that should block this PR.

Would it be advantageous to only keep the * operator defined (needed for @scalar_rule to work), and define errors for common others (+, -, / adjoint transpose etc.) so that most users see the nice error message rather than a MethodError?

Just so we don't forget, a few of things to do before we merge:

  • version bump
  • tests
  • documentation (explaining when to use this macro, should probably go in "Writing Good Rules" section)


# Linear operators
Base.adjoint(x::NotImplemented) = x
Base.transpose(x::NotImplemented) = x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably don't need these two since conj falls back to the identical AbstractDifferential definition.


Base.:+(x::NotImplemented, ::Any) = x
Base.:+(::Any, x::NotImplemented) = x
Base.:+(x::NotImplemented, ::NotImplemented) = x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we need these? Perhaps we could make it error?

src/differentials/notimplemented.jl Outdated Show resolved Hide resolved
@mzgubic
Copy link
Member

mzgubic commented Apr 19, 2021

Ah, one more thing:

julia> @btime  Zygote.gradient(x -> besseli(x, 1.0), 1.0)[1]
  351.905 ns (3 allocations: 48 bytes)
NotImplemented(SpecialFunctions, #= /Users/mzgubic/JuliaEnvs/test/notimplemented/dev/SpecialFunctions/src/chainrules.jl:42 =#, besseli first derivative is not implemented)

julia>  @btime  Zygote.gradient(x -> besseli(x, 1.0), 1.0)[1]
  373.468 ns (3 allocations: 48 bytes)
NotImplemented()

It doesn't seem that the extra info is slowing it down, I would be in favour of using the extra information without debug mode in order to make it more user friendly

@devmotion
Copy link
Member Author

Would it be advantageous to only keep the * operator defined (needed for @scalar_rule to work), and define errors for common others (+, -, / adjoint transpose etc.) so that most users see the nice error message rather than a MethodError?

AFAICT the current implementation of @scalar_rule requires us to define at least conj(::NotImplemented) (already provided by the default definition for AbstractDifferential), *(::NotImplemented, ::Any), and muladd(::NotImplemented, ::Any, ::Any), muladd(::Any, ::Any, ::NotImplemented), and muladd(::NotImplemented, ::Any, ::NotImplemented). But I'll check this more carefully.

@devmotion
Copy link
Member Author

It doesn't seem that the extra info is slowing it down, I would be in favour of using the extra information without debug mode in order to make it more user friendly

Nice, thanks for looking into this! Would you like me to perform some more benchmarks? It's great if we can provide better error messages without affecting performance 👍

@mzgubic
Copy link
Member

mzgubic commented Apr 19, 2021

The other extreme I guess is where one looks at

julia>  @btime  @not_implemented("myerror")
  0.047 ns (0 allocations: 0 bytes)
NotImplemented()

julia> @btime  @not_implemented("myerror")
  2.712 ns (0 allocations: 0 bytes)
NotImplemented(Main, #= REPL[6]:1 =#, myerror)

But since it's such a small absolute number and the linear operators should constant fold it is probably fine? If you can think of more meaningful benchmarks I would love to see them

@oxinabox
Copy link
Member

I have marked this as assigned to @mzgubic to take over the line for final review.
since it doesn't need two of us commenting on it.
For what its worth i don't have any problem with defining conj or any other linear operator on top of NotImplemented.
Its a bit like missing propagation.
But better since we know where it came from.

Co-authored-by: Lyndon White <[email protected]>
@devmotion
Copy link
Member Author

devmotion commented Apr 19, 2021

The other extreme I guess is where one looks at

julia>  @btime  @not_implemented("myerror")
  0.047 ns (0 allocations: 0 bytes)
NotImplemented()

julia> @btime  @not_implemented("myerror")
  2.712 ns (0 allocations: 0 bytes)
NotImplemented(Main, #= REPL[6]:1 =#, myerror)

But since it's such a small absolute number and the linear operators should constant fold it is probably fine? If you can think of more meaningful benchmarks I would love to see them

The timings are basically the same if one wraps it in a function:

julia> g() = ChainRulesCore.@not_implemented("test")
g (generic function with 2 methods)

julia> ChainRulesCore.debug_mode() = false

julia> @btime g()
  2.306 ns (0 allocations: 0 bytes)
NotImplemented(Main, #= REPL[43]:1 =#, test)

julia> ChainRulesCore.debug_mode() = true

julia> @btime g()
  2.307 ns (0 allocations: 0 bytes)
NotImplemented(Main, #= REPL[43]:1 =#, test)

Since I haven't found an example where the debugging information affects performance I remove the check for debug_mode().

Edit: I missed that I had enabled debug mode in both benchmarks, I observe the same difference if performed correctly:

julia> ChainRulesCore.debug_mode() = false

julia> g() = ChainRulesCore.@not_implemented("test")
g (generic function with 2 methods)

julia> @btime g()
  0.018 ns (0 allocations: 0 bytes)
NotImplemented()

julia> ChainRulesCore.debug_mode() = true

julia> g() = ChainRulesCore.@not_implemented("test")
g (generic function with 1 method

julia> @btime g()
  2.306 ns (0 allocations: 0 bytes)
NotImplemented(Main, #= REPL[10]:1 =#, test)

@devmotion devmotion changed the title [WIP] Add @not_implemented Add @not_implemented Apr 20, 2021
@devmotion devmotion marked this pull request as ready for review April 20, 2021 07:58
docs/src/writing_good_rules.md Outdated Show resolved Hide resolved

One can use [`@not_implemented`](@ref) to mark missing differentials. This is helpful if
the function has multiple inputs or outputs, and one has worked out analytically and
implemented some but not all differentials.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a nitpick, but in markdown we usually have one sentence per line (rather than the character limit)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK, I'll fix it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

src/differential_arithmetic.jl Show resolved Hide resolved
test/differentials/notimplemented.jl Outdated Show resolved Hide resolved
test/differentials/notimplemented.jl Outdated Show resolved Hide resolved
@mzgubic
Copy link
Member

mzgubic commented Apr 20, 2021

Thank you for this wonderful contribution! I've added a few minor comments, but I think we are good to go. Let me know when you would like it merged.

@oxinabox
Copy link
Member

Let me know when you would like it merged.

@devmotion has merge rights on the ChainRules project so can merge when ready

devmotion and others added 2 commits April 20, 2021 12:31
@mzgubic
Copy link
Member

mzgubic commented Apr 20, 2021

Let me know when you would like it merged.

@devmotion has merge rights on the ChainRules project so can merge when ready

Oh, ok, I assumed not since the PR is from a fork, please go ahead @devmotion

How do i check who has merge rights? It came up a couple of times now that I wish I knew

@devmotion
Copy link
Member Author

@oxinabox @mzgubic Anything else that you would like to be addressed?

@devmotion devmotion merged commit 208c85e into JuliaDiff:master Apr 20, 2021
@devmotion devmotion deleted the dw/notimplemented branch April 20, 2021 12:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants