-
Notifications
You must be signed in to change notification settings - Fork 5
Conversation
to do the inplace I would like to have but i don't need it since can just overload |
src/sensitivities/chainrules.jl
Outdated
end | ||
|
||
"like `ExprTools.signature` but on a signature type-tuple, not a Method" | ||
function build_def(sig) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possibly this should move into ExprTools.
Needs invenia/ExprTools.jl#12 |
Probably what this should do is look at the method table and check if the simple unionized overload would eclipse any in the wrong way (what is that? I need to think carefully). Though maybe that check would take longer than the extra processing time to generate and load all of them |
0993838
to
a335e03
Compare
Co-authored-by: Curtis Vogt <[email protected]>
Co-authored-by: Curtis Vogt <[email protected]>
… passing in a preprocess output
Co-authored-by: Eric Davies <[email protected]>
…nRules.jl correctly
I thought i was done, Also I realised the docs wouldn't build anymore. Should be all sorted now |
Supercedes #178
Follows https://www.juliadiff.org/ChainRulesCore.jl/dev/autodiff/operator_overloading.html
(which will probably get updates during this based on practical learnings)
What this PR does:
Nabla.update!
swapped out for the very similaradd!!
so it would work forInplaceableThunk
sForwardDiff.derivative
but for multiargument things it still uses the Dual numbers directly.node_type
tranform that is likeunionise_type
transform but without making the unionpreprocess
not receive its inputs pre-unboxe-d, but have the default fallback unbox them and recall processVarArg{T, N} where N
add support for SpecialFunction 0.10Drops support for SpecialFunctions 0.9.
lgamma
/loggamma
as they don't both exist in nondeprecated form in same versionThings i suggest leaving for potential future PRs
(but that reviewers might disagree with)
:no_N
in src/conde_tranformations/utils.jl. Which doesn’t seem to ever be hitPair{Node}
returnNode{Pair}
for consistency and sodiagm
will hit rules we define in ChainRules.jlNotes on implementation
The core logic is to use of the Operator Overloading interface of ChainRules, which lets you register a hook that is triggered passing in a type- type representing the signature of every primal function that ChainRulesCore has an overload of
rrule
for.This hook is the
generate_overload
function.This filters out a bunch of things.
It then uses ExprTools to get a AST for function defination that would be suitable for overloading the primal function (as an overloading based AD like Nabla does).
From that it generates: overloads for that primal but with in turn each argument swapped out for the matching node (this is why
node_type
was added to the code tranformation functions).And earlier version use
unionise_type
instead of swapping it out, but for things with primal type ofAny
(which shows up fornondifferentiable_rule
), this just resulted inUnion{Node{Any}, Any}
which simplifies toAny
. Which mean we were overwriting the original primal definition which will break everything.The key thing these generated primal overloads do is create a
Branch
that stores the pullback.We then generate a method for
preprocess
which invokes that pullback, computing the partials for all the arguments.And we generate a method for
∇
that just talkes that partial computed by preprocess and return the right one for the specifiedArg{N}
.Things to do before Review
Things for reviewers to consider:
src/sensitivities/chainrules.jl
filemap
.Pair{<:Node, <:Node}
, rather thanNode{<:Pair}
rrule
is added to ChainRules for something Nabla has it will cause Nabla to break due to ambiguity.Arg{1}
andArg{2}
etc cases. That would remove ambiguities i think.