-
Notifications
You must be signed in to change notification settings - Fork 109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
mlir: Func call reverse diff #2127
Conversation
enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp
Outdated
Show resolved
Hide resolved
enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp
Outdated
Show resolved
Hide resolved
|
||
std::vector<bool> volatile_args(narg, true); | ||
std::vector<bool> returnShadow(narg, false); | ||
std::vector<bool> returnPrimal(nret, false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably default to always returning the primal, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really understand how the primal would be used here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah we would only need for fused forward and reverse passes or handling of mutation, we can defer that for later here
bool freeMemory = true; | ||
size_t width = 1; | ||
|
||
auto revFn = gutils->Logic.CreateReverseDiff( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Long term We probably need to do the same augmented forward pass and separate reverse pass that enzyme llvm does atm .
if we assume everything in the function is read only and only return active not duplicated results this is fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah right, the caching at the primal call site is not enough if the arguments are mutable. I could not make an example input with memrefs (something about invertPointerM). So I added an error if any of the result or arg of the initial call is mutable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but some misc fix ups and also the assumptions of the current implementation should throw errors if not met (we have a isReadOnly for example)
@@ -29,9 +29,139 @@ namespace { | |||
#include "Implementations/FuncDerivatives.inc" | |||
} // namespace | |||
|
|||
static std::optional<func::FuncOp> getContainingFunction(Operation *orig) { | |||
Operation *parent; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be a function interface instead?
|
||
auto parent = getContainingFunction(orig); | ||
if (parent.has_value() && | ||
callOp.getCallee() == parent.value().getNameAttr()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I see this use, this is super fragile. For example it won’t work if you have a calls b calls a.
I think this is okay to merge as is but adding proper recursion support shouldn’t be bad. We have it already for forward enzymemlir which basically just requires having a cache in enzymelogic asking if we’ve already made this derivative function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah any cycle in the call graph will not be caught here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah enzymelogic should do this for forwards (and also llvm) already
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will look at implementing something like ForwardCachedFunctions for reverse in a follow-up
* main: (49 commits) Fix iv of constant (#2141) Update benchmarks (#2035) Implement tgamma derivative (#2140) tgamma error improvement (#2139) Improve cache index error message (#2138) Fixes warnings and adds missing header guards (#2124) mlir: cache and reuse reverse funcs (#2133) mlir: implement forward mode for func.call (#2134) mlir: Func call reverse diff (#2127) Update build_tarballs.jl Fix combined temp cache for reverse (#2131) Improve runtime activity err message (#2132) Fix undef value storage (#2129) Adapt to const tblgen (#2128) Add gcloaded TT (#2125) Fix blas decl updater indexing (#2123) Add header files to ClangEnzyme target (#2062) Improve unknown function error messages (#2120) Fix handle sync (#2122) Support more Julia 1.11 functions (#2121) ...
* main: (49 commits) Fix iv of constant (#2141) Update benchmarks (#2035) Implement tgamma derivative (#2140) tgamma error improvement (#2139) Improve cache index error message (#2138) Fixes warnings and adds missing header guards (#2124) mlir: cache and reuse reverse funcs (#2133) mlir: implement forward mode for func.call (#2134) mlir: Func call reverse diff (#2127) Update build_tarballs.jl Fix combined temp cache for reverse (#2131) Improve runtime activity err message (#2132) Fix undef value storage (#2129) Adapt to const tblgen (#2128) Add gcloaded TT (#2125) Fix blas decl updater indexing (#2123) Add header files to ClangEnzyme target (#2062) Improve unknown function error messages (#2120) Fix handle sync (#2122) Support more Julia 1.11 functions (#2121) ...
prompted by EnzymeAD/Reactant.jl#171 (comment).
We can now run enzyme-batch -> enzyme-wrap without inlining.