mlir: Func call reverse diff #2127

Merged 7 commits on Oct 23, 2024
131 changes: 131 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp
@@ -29,9 +29,140 @@ namespace {
#include "Implementations/FuncDerivatives.inc"
} // namespace

static std::optional<mlir::FunctionOpInterface>
getContainingFunction(Operation *orig) {
Operation *parent;
while ((parent = orig->getParentOp())) {
if (auto func = dyn_cast<mlir::FunctionOpInterface>(parent)) {
return std::optional(func);
}
orig = parent;
}

return std::nullopt;
}

Member: Can this be a function interface instead?

class AutoDiffCallRev
: public ReverseAutoDiffOpInterface::ExternalModel<AutoDiffCallRev,
func::CallOp> {
public:
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
DerivativeMode mode = DerivativeMode::ReverseModeGradient;

SymbolTable symbolTable = SymbolTable::getNearestSymbolTable(orig);

func::CallOp callOp = cast<func::CallOp>(orig);

auto parent = getContainingFunction(orig);
if (parent.has_value() &&
callOp.getCallee() == parent.value().getNameAttr()) {
// TODO: handle recursion chains.
orig->emitError() << "could not emit adjoint of recursive call at: "
<< *orig << "\n";
return failure();
}

Member: Ah, I see this use; it is super fragile. For example, it won't work if you have a calls b calls a.

I think this is okay to merge as is, but adding proper recursion support shouldn't be bad. We already have it for forward-mode EnzymeMLIR, which basically just requires having a cache in EnzymeLogic that asks whether we have already made this derivative function.

@Pangoraw (author, Oct 23, 2024): Yeah, any cycle in the call graph will not be caught here.

Member: Yeah, EnzymeLogic should do this for forward mode (and also LLVM) already.

@Pangoraw (author): I will look at implementing something like ForwardCachedFunctions for reverse mode in a follow-up.
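
For illustration, here is a minimal input with the kind of call-graph cycle discussed above (a hypothetical example, not one of this PR's tests): @a calls @b and @b calls @a, so neither callee name matches its enclosing function's name and the check does not fire.

module {
// Neither call is directly self-recursive, so the callee == parent-name
// check above does not flag the cycle, and differentiation would recurse.
func.func @a(%arg0: f32) -> f32 {
%0 = func.call @b(%arg0) : (f32) -> f32
return %0 : f32
}
func.func @b(%arg0: f32) -> f32 {
%0 = func.call @a(%arg0) : (f32) -> f32
return %0 : f32
}
}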

Operation *callee = symbolTable.lookup(callOp.getCallee());
auto fn = cast<FunctionOpInterface>(callee);

auto narg = orig->getNumOperands();
auto nret = orig->getNumResults();

std::vector<DIFFE_TYPE> RetActivity;
for (auto res : callOp.getResults()) {
RetActivity.push_back(
gutils->isConstantValue(res) ? DIFFE_TYPE::CONSTANT
: res.getType().cast<AutoDiffTypeInterface>().isMutable()
? DIFFE_TYPE::DUP_ARG
: DIFFE_TYPE::OUT_DIFF);
}

std::vector<DIFFE_TYPE> ArgActivity;
for (auto arg : callOp.getOperands()) {
ArgActivity.push_back(
gutils->isConstantValue(arg) ? DIFFE_TYPE::CONSTANT
: arg.getType().cast<AutoDiffTypeInterface>().isMutable()
? DIFFE_TYPE::DUP_ARG
: DIFFE_TYPE::OUT_DIFF);
}

if (llvm::any_of(ArgActivity,
[&](auto act) { return act == DIFFE_TYPE::DUP_ARG; }) ||
llvm::any_of(RetActivity,
[&](auto act) { return act == DIFFE_TYPE::DUP_ARG; })) {
// NOTE: the current approach fails when the function is not read-only,
// i.e. when it can modify its arguments.
orig->emitError() << "could not emit adjoint with mutable types in: "
<< *orig << "\n";
return failure();
}

std::vector<bool> volatile_args(narg, true);
std::vector<bool> returnShadow(narg, false);
std::vector<bool> returnPrimal(nret, false);
Member: We should probably default to always returning the primal, no?

@Pangoraw (author): I don't really understand how the primal would be used here.

Member: Yeah, we would only need it for fused forward-and-reverse passes or for handling mutation; we can defer that for later here.

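For context, a sketch of what setting returnPrimal would plausibly produce for the @square example in the test below, assuming the flag behaves as in Enzyme's LLVM mode (a hypothetical signature, not output generated by this PR):

// Hypothetical: with returnPrimal = true, the adjoint would also return
// the recomputed primal result alongside the gradient.
func.func private @diffesquare(%arg0: f32, %arg1: f32) -> (f32, f32) {
%0 = arith.mulf %arg0, %arg0 : f32
%1 = arith.mulf %arg1, %arg0 : f32
%2 = arith.mulf %arg1, %arg0 : f32
%3 = arith.addf %1, %2 : f32
return %0, %3 : f32, f32
}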

auto type_args = gutils->TA.getAnalyzedTypeInfo(fn);

bool freeMemory = true;
size_t width = 1;

auto revFn = gutils->Logic.CreateReverseDiff(
fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, returnShadow,
mode, freeMemory, width, /*addedType*/ nullptr, type_args,
volatile_args, /*augmented*/ nullptr);

Member: Long term, we probably need to do the same augmented forward pass and separate reverse pass that Enzyme's LLVM mode does at the moment.

If we assume everything in the function is read-only and we only return active, non-duplicated results, this is fine.

@Pangoraw (author): Ah right, the caching at the primal call site is not enough if the arguments are mutable. I could not make an example input with memrefs (something about invertPointerM), so I added an error if any result or argument of the initial call is mutable.
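
To illustrate the rejected case, here is a minimal input whose call takes a mutable memref operand (a hypothetical example, not one of this PR's tests; it assumes the memref type's AutoDiffTypeInterface reports isMutable). The operand's activity is classified as DUP_ARG, so the pass emits the mutable-types error above instead of building an adjoint:

module {
// @scale writes through its memref argument, so caching the primal
// call's operands is not enough to rebuild the reverse pass.
func.func @scale(%arg0: memref<f32>, %arg1: f32) {
%0 = memref.load %arg0[] : memref<f32>
%1 = arith.mulf %0, %arg1 : f32
memref.store %1, %arg0[] : memref<f32>
return
}
func.func @main(%arg0: memref<f32>, %arg1: f32) {
func.call @scale(%arg0, %arg1) : (memref<f32>, f32) -> ()
return
}
}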

SmallVector<Value> revArguments;

for (auto cache : caches) {
revArguments.push_back(gutils->popCache(cache, builder));
}

for (auto result : callOp.getResults()) {
if (gutils->isConstantValue(result))
continue;
revArguments.push_back(gutils->diffe(result, builder));
}

auto revCallOp = builder.create<func::CallOp>(
orig->getLoc(), cast<func::FuncOp>(revFn), revArguments);

int revIndex = 0;
for (auto arg : callOp.getOperands()) {
if (gutils->isConstantValue(arg))
continue;
auto diffe = revCallOp.getResult(revIndex);
gutils->addToDiffe(arg, diffe, builder);
revIndex++;
}

return success();
}

SmallVector<Value> cacheValues(Operation *orig,
MGradientUtilsReverse *gutils) const {
SmallVector<Value> cachedArguments;

Operation *newOp = gutils->getNewFromOriginal(orig);
OpBuilder cacheBuilder(newOp);

for (auto arg : orig->getOperands()) {
Value cache = gutils->initAndPushCache(gutils->getNewFromOriginal(arg),
cacheBuilder);
cachedArguments.push_back(cache);
}

return cachedArguments;
}

void createShadowValues(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) const {}
};

void mlir::enzyme::registerFuncDialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, func::FuncDialect *) {
registerInterfaces(context);
func::CallOp::attachInterface<AutoDiffCallRev>(*context);
});
}
3 changes: 3 additions & 0 deletions enzyme/Enzyme/MLIR/enzymemlir-opt.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -72,6 +73,8 @@ int main(int argc, char **argv) {

mlir::registerenzymePasses();

mlir::func::registerInlinerExtension(registry);

// Register the standard passes we want.
mlir::registerCSEPass();
mlir::registerConvertAffineToStandardPass();
44 changes: 44 additions & 0 deletions enzyme/test/MLIR/ReverseMode/func.mlir
@@ -0,0 +1,44 @@
// RUN: %eopt --enzyme-wrap="infn=main outfn= retTys=enzyme_active argTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math %s | FileCheck %s

module {
func.func @square(%arg0: f32) -> f32 {
%0 = arith.mulf %arg0, %arg0 : f32
return %0 : f32
}

func.func @inactive(%arg0: f32, %arg1: i32) -> (f32, i32) {
%0 = arith.constant 42.0 : f32
%1 = arith.constant 42 : i32
return %0, %1 : f32, i32
}

func.func @main(%arg0: f32) -> f32 {
%0 = func.call @square(%arg0) : (f32) -> f32
%1 = arith.constant 10 : i32
%2:2 = func.call @inactive(%arg0, %1) : (f32, i32) -> (f32, i32)
%3 = arith.addf %0, %2#0 : f32
return %3 : f32
}
}

// CHECK: func.func @main(%arg0: f32, %arg1: f32) -> f32 {
// CHECK-NEXT: %c10_i32 = arith.constant 10 : i32
// CHECK-NEXT: %0 = call @square(%arg0) : (f32) -> f32
// CHECK-NEXT: %1:2 = call @inactive(%arg0, %c10_i32) : (f32, i32) -> (f32, i32)
// CHECK-NEXT: %2 = call @diffeinactive(%arg0, %c10_i32, %arg1) : (f32, i32, f32) -> f32
// CHECK-NEXT: %3 = call @diffesquare(%arg0, %arg1) : (f32, f32) -> f32
// CHECK-NEXT: %4 = arith.addf %2, %3 : f32
// CHECK-NEXT: return %4 : f32
// CHECK-NEXT: }

// CHECK: func.func private @diffeinactive(%arg0: f32, %arg1: i32, %arg2: f32) -> f32 {
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: return %cst : f32
// CHECK-NEXT: }

// CHECK: func.func private @diffesquare(%arg0: f32, %arg1: f32) -> f32 {
// CHECK-NEXT: %0 = arith.mulf %arg1, %arg0 : f32
// CHECK-NEXT: %1 = arith.mulf %arg1, %arg0 : f32
// CHECK-NEXT: %2 = arith.addf %0, %1 : f32
// CHECK-NEXT: return %2 : f32
// CHECK-NEXT: }