mlir: Func call reverse diff #2127
Changes from all commits: 3a3f0aa, 1c32aef, beb046d, 6780ab7, 2d0a262, 1971a58, 719930d
@@ -29,9 +29,140 @@ namespace {
#include "Implementations/FuncDerivatives.inc"
} // namespace

static std::optional<mlir::FunctionOpInterface>
getContainingFunction(Operation *orig) {
  // Walk the parent chain upward (not just the immediate parent) until a
  // function-like op is found.
  Operation *parent = orig->getParentOp();
  while (parent) {
    if (auto func = dyn_cast<mlir::FunctionOpInterface>(parent))
      return std::optional(func);
    parent = parent->getParentOp();
  }

  return std::nullopt;
}
class AutoDiffCallRev
    : public ReverseAutoDiffOpInterface::ExternalModel<AutoDiffCallRev,
                                                       func::CallOp> {
public:
  LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
                                         MGradientUtilsReverse *gutils,
                                         SmallVector<Value> caches) const {
    DerivativeMode mode = DerivativeMode::ReverseModeGradient;

    SymbolTable symbolTable(SymbolTable::getNearestSymbolTable(orig));

    func::CallOp callOp = cast<func::CallOp>(orig);

    auto parent = getContainingFunction(orig);
    if (parent.has_value() &&
        callOp.getCallee() == parent.value().getNameAttr()) {
      // TODO: Recursion chains
      orig->emitError() << "could not emit adjoint of recursive call at: "
                        << *orig << "\n";
      return failure();
    }

Review comment (reviewer): Ah, I see this use; this is super fragile. For example, it won't work if a calls b calls a. I think this is okay to merge as is, but adding proper recursion support shouldn't be bad. We already have it for forward-mode Enzyme MLIR, which basically just requires a cache in EnzymeLogic recording whether we've already made this derivative function.

Review comment (author): Yeah, any cycle in the call graph will not be caught here.

Review comment (reviewer): Yeah, EnzymeLogic should do this for forward mode (and also LLVM) already.

Review comment (author): I will look at implementing something like ForwardCachedFunctions for reverse in a follow-up.
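As a sketch of the follow-up the thread describes, a cache keyed by the callee and its activity signature would let recursive (and mutually recursive) calls reuse the derivative function that is already being generated, instead of erroring. The names below (ReverseCacheKey, reverseCachedFunctions) are hypothetical and only illustrate the shape of the ForwardCachedFunctions approach, not the current EnzymeLogic API:

// Hypothetical sketch mirroring ForwardCachedFunctions; illustrative only.
struct ReverseCacheKey {
  Operation *callee;                   // the function being differentiated
  std::vector<DIFFE_TYPE> retActivity; // per-result activity
  std::vector<DIFFE_TYPE> argActivity; // per-operand activity
  bool operator<(const ReverseCacheKey &rhs) const {
    return std::tie(callee, retActivity, argActivity) <
           std::tie(rhs.callee, rhs.retActivity, rhs.argActivity);
  }
};

// Inside the logic object (hypothetical member):
//   std::map<ReverseCacheKey, FunctionOpInterface> reverseCachedFunctions;
//
// CreateReverseDiff would register the derivative's declaration *before*
// differentiating its body, so a recursive call site finds the entry and
// reuses it rather than recursing forever:
//   ReverseCacheKey key{fn.getOperation(), RetActivity, ArgActivity};
//   auto found = reverseCachedFunctions.find(key);
//   if (found != reverseCachedFunctions.end())
//     return found->second;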
    Operation *callee = symbolTable.lookup(callOp.getCallee());
    auto fn = cast<FunctionOpInterface>(callee);

    auto narg = orig->getNumOperands();
    auto nret = orig->getNumResults();

    // Classify each result: constant values carry no derivative, mutable
    // types would need a shadow, and immutable active values flow as OUT_DIFF.
    std::vector<DIFFE_TYPE> RetActivity;
    for (auto res : callOp.getResults()) {
      RetActivity.push_back(
          gutils->isConstantValue(res) ? DIFFE_TYPE::CONSTANT
          : res.getType().cast<AutoDiffTypeInterface>().isMutable()
              ? DIFFE_TYPE::DUP_ARG
              : DIFFE_TYPE::OUT_DIFF);
    }

    // Same classification for each operand.
    std::vector<DIFFE_TYPE> ArgActivity;
    for (auto arg : callOp.getOperands()) {
      ArgActivity.push_back(
          gutils->isConstantValue(arg) ? DIFFE_TYPE::CONSTANT
          : arg.getType().cast<AutoDiffTypeInterface>().isMutable()
              ? DIFFE_TYPE::DUP_ARG
              : DIFFE_TYPE::OUT_DIFF);
    }
    if (llvm::any_of(ArgActivity,
                     [&](auto act) { return act == DIFFE_TYPE::DUP_ARG; }) ||
        llvm::any_of(RetActivity,
                     [&](auto act) { return act == DIFFE_TYPE::DUP_ARG; })) {
      // NOTE: the current approach fails when the function is not read-only,
      // i.e. when it can modify its arguments.
      orig->emitError() << "could not emit adjoint with mutable types in: "
                        << *orig << "\n";
      return failure();
    }

    std::vector<bool> volatile_args(narg, true);
    std::vector<bool> returnShadow(narg, false);
    std::vector<bool> returnPrimal(nret, false);

Review comment (reviewer): We should probably default to always returning the primal, no?

Review comment (author): I don't really understand how the primal would be used here.

Review comment (reviewer): Yeah, we would only need it for fused forward and reverse passes or for handling mutation; we can defer that for later here.
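If primal returns are later wanted for the fused forward-and-reverse case the thread mentions, the change would start from flipping this flag; a one-line sketch under that assumption (not what this patch emits):

// Hypothetical follow-up: also return the primal results from the
// generated reverse function, for fused forward+reverse or mutation.
std::vector<bool> returnPrimal(nret, true);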
    auto type_args = gutils->TA.getAnalyzedTypeInfo(fn);

    bool freeMemory = true;
    size_t width = 1;

    // Ask EnzymeLogic to synthesize the reverse-mode function for the callee.
    auto revFn = gutils->Logic.CreateReverseDiff(
        fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, returnShadow,
        mode, freeMemory, width, /*addedType*/ nullptr, type_args,
        volatile_args, /*augmented*/ nullptr);

Review comment (reviewer): Long term, we probably need to do the same augmented forward pass and separate reverse pass that Enzyme LLVM does at the moment. If we assume everything in the function is read-only and we only return active, not duplicated, results, this is fine.

Review comment (author): Ah right, the caching at the primal call site is not enough if the arguments are mutable. I could not make an example input with memrefs (something about invertPointerM), so I added an error if any result or argument of the initial call is mutable.
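For orientation, the calling convention of the generated function can be read off the FileCheck expectations in the test below: the reverse function takes the cached primal operands followed by one adjoint per active result, and returns one adjoint per OUT_DIFF operand.

// From the test below, for @square : (f32) -> f32:
//   func.func private @diffesquare(%arg0: f32, %arg1: f32) -> f32
// %arg0 is the cached primal operand, %arg1 the adjoint of the result,
// and the returned f32 is the adjoint of the operand.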
    SmallVector<Value> revArguments;

    // Recover the primal operands cached during the forward sweep, in
    // operand order.
    for (auto cache : caches) {
      revArguments.push_back(gutils->popCache(cache, builder));
    }

    // Append the adjoint of every active result.
    for (auto result : callOp.getResults()) {
      if (gutils->isConstantValue(result))
        continue;
      revArguments.push_back(gutils->diffe(result, builder));
    }

    auto revCallOp = builder.create<func::CallOp>(
        orig->getLoc(), cast<func::FuncOp>(revFn), revArguments);

    // Accumulate the returned adjoints into the active operands.
    int revIndex = 0;
    for (auto arg : callOp.getOperands()) {
      if (gutils->isConstantValue(arg))
        continue;
      auto diffe = revCallOp.getResult(revIndex);
      gutils->addToDiffe(arg, diffe, builder);
      revIndex++;
    }

    return success();
  }
  SmallVector<Value> cacheValues(Operation *orig,
                                 MGradientUtilsReverse *gutils) const {
    SmallVector<Value> cachedArguments;

    Operation *newOp = gutils->getNewFromOriginal(orig);
    OpBuilder cacheBuilder(newOp);

    // Cache every primal operand during the forward sweep so the reverse
    // sweep can rebuild the call's argument list.
    for (auto arg : orig->getOperands()) {
      Value cache = gutils->initAndPushCache(gutils->getNewFromOriginal(arg),
                                             cacheBuilder);
      cachedArguments.push_back(cache);
    }

    return cachedArguments;
  }
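The operand order here is load-bearing: cacheValues pushes one cache per operand, in operand order, and createReverseModeAdjoint above pops them in the same order when assembling revArguments, so the recovered primals line up with the callee's signature. A minimal sketch of the pairing, assuming initAndPushCache and popCache form a matched pair keyed by the returned cache value:

// Forward sweep (cacheValues): one cache per primal operand, in order.
//   Value c = gutils->initAndPushCache(primalOperand, cacheBuilder);
// Reverse sweep (createReverseModeAdjoint): retrieved via the same value.
//   Value primalOperand = gutils->popCache(c, builder);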
  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {}
};

void mlir::enzyme::registerFuncDialectAutoDiffInterface(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *context, func::FuncDialect *) {
    registerInterfaces(context);
    func::CallOp::attachInterface<AutoDiffCallRev>(*context);
  });
}
@@ -0,0 +1,44 @@
// RUN: %eopt --enzyme-wrap="infn=main outfn= retTys=enzyme_active argTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math %s | FileCheck %s

module {
  func.func @square(%arg0: f32) -> f32 {
    %0 = arith.mulf %arg0, %arg0 : f32
    return %0 : f32
  }

  func.func @inactive(%arg0: f32, %arg1: i32) -> (f32, i32) {
    %0 = arith.constant 42.0 : f32
    %1 = arith.constant 42 : i32
    return %0, %1 : f32, i32
  }

  func.func @main(%arg0: f32) -> f32 {
    %0 = func.call @square(%arg0) : (f32) -> f32
    %1 = arith.constant 10 : i32
    %2:2 = func.call @inactive(%arg0, %1) : (f32, i32) -> (f32, i32)
    %3 = arith.addf %0, %2#0 : f32
    return %3 : f32
  }
}
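As a hand check of what the CHECK lines below expect: main computes square(x) + inactive(x, 10)#0 = x^2 + 42, so d main / dx = 2x. Accordingly @diffesquare applies the product rule to x*x, accumulating dy*x + dy*x = 2x*dy, while @diffeinactive returns a constant 0 adjoint for its active f32 argument.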
// CHECK: func.func @main(%arg0: f32, %arg1: f32) -> f32 {
// CHECK-NEXT: %c10_i32 = arith.constant 10 : i32
// CHECK-NEXT: %0 = call @square(%arg0) : (f32) -> f32
// CHECK-NEXT: %1:2 = call @inactive(%arg0, %c10_i32) : (f32, i32) -> (f32, i32)
// CHECK-NEXT: %2 = call @diffeinactive(%arg0, %c10_i32, %arg1) : (f32, i32, f32) -> f32
// CHECK-NEXT: %3 = call @diffesquare(%arg0, %arg1) : (f32, f32) -> f32
// CHECK-NEXT: %4 = arith.addf %2, %3 : f32
// CHECK-NEXT: return %4 : f32
// CHECK-NEXT: }

// CHECK: func.func private @diffeinactive(%arg0: f32, %arg1: i32, %arg2: f32) -> f32 {
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: return %cst : f32
// CHECK-NEXT: }

// CHECK: func.func private @diffesquare(%arg0: f32, %arg1: f32) -> f32 {
// CHECK-NEXT: %0 = arith.mulf %arg1, %arg0 : f32
// CHECK-NEXT: %1 = arith.mulf %arg1, %arg0 : f32
// CHECK-NEXT: %2 = arith.addf %0, %1 : f32
// CHECK-NEXT: return %2 : f32
// CHECK-NEXT: }
Review comment (reviewer): Can this be a function interface instead?