From f0f64db0ca8022e185f1299ab0afae72be46a555 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 May 2022 14:24:53 -0400 Subject: [PATCH] Add forward mode fadd and free debug (#674) --- enzyme/Enzyme/AdjointGenerator.h | 53 +++++++++++++++++++- enzyme/Enzyme/GradientUtils.cpp | 2 +- enzyme/test/Enzyme/ForwardMode/atomicfadd.ll | 29 +++++++++++ 3 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 enzyme/test/Enzyme/ForwardMode/atomicfadd.ll diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index dbc5dcb077d95..62798681f7737 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -778,6 +778,56 @@ class AdjointGenerator } void visitAtomicRMWInst(llvm::AtomicRMWInst &I) { + if (Mode == DerivativeMode::ForwardMode) { + IRBuilder<> BuilderZ(&I); + getForwardBuilder(BuilderZ); + switch (I.getOperation()) { + case AtomicRMWInst::FAdd: + case AtomicRMWInst::FSub: { + auto rule = [&](Value *ptr, Value *dif) -> Value * { + if (!gutils->isConstantInstruction(&I)) { + assert(ptr); + AtomicRMWInst *rmw = nullptr; +#if LLVM_VERSION_MAJOR >= 13 + rmw = BuilderZ.CreateAtomicRMW(I.getOperation(), ptr, dif, + I.getAlign(), I.getOrdering(), + I.getSyncScopeID()); +#elif LLVM_VERSION_MAJOR >= 11 + rmw = BuilderZ.CreateAtomicRMW(I.getOperation(), ptr, dif, + I.getOrdering(), I.getSyncScopeID()); + rmw->setAlignment(I.getAlign()); +#else + rmw = BuilderZ.CreateAtomicRMW( + I.getOperation(), ptr, dif, I.getOrdering(), + I.getSyncScopeID()); +#endif + rmw->setVolatile(I.isVolatile()); + if (gutils->isConstantValue(&I)) + return Constant::getNullValue(dif->getType()); + else + return rmw; + } else { + assert(gutils->isConstantValue(&I)); + return Constant::getNullValue(dif->getType()); + } + }; + + Value *diff = applyChainRule( + I.getType(), BuilderZ, rule, + gutils->isConstantValue(I.getPointerOperand()) + ? nullptr + : gutils->invertPointerM(I.getPointerOperand(), BuilderZ), + gutils->isConstantValue(I.getValOperand()) + ? Constant::getNullValue(I.getType()) + : gutils->invertPointerM(I.getValOperand(), BuilderZ)); + if (!gutils->isConstantValue(&I)) + setDiffe(&I, diff, BuilderZ); + return; + } + default: + break; + } + } if (!gutils->isConstantInstruction(&I) || !gutils->isConstantValue(&I)) { TR.dump(); llvm::errs() << "oldFunc: " << *gutils->newFunc << "\n"; @@ -11083,7 +11133,8 @@ class AdjointGenerator auto rule = [&args](Value *tofree) { args.push_back(tofree); }; applyChainRule(Builder2, rule, tofree); - Builder2.CreateCall(free->getFunctionType(), free, args); + auto frees = Builder2.CreateCall(free->getFunctionType(), free, args); + frees->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc())); return; } diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 88a0797ab6ea7..14791690ba6f8 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3821,7 +3821,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, } } - if (auto arg = dyn_cast(oval)) { + if (isa(oval)) { auto rule = [&oval]() { return oval; }; return applyChainRule(oval->getType(), BuilderM, rule); } diff --git a/enzyme/test/Enzyme/ForwardMode/atomicfadd.ll b/enzyme/test/Enzyme/ForwardMode/atomicfadd.ll new file mode 100644 index 0000000000000..feb752ed537b2 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/atomicfadd.ll @@ -0,0 +1,29 @@ +; RUN: if [ %llvmver -ge 9 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi + +; Function Attrs: norecurse nounwind readonly uwtable +define dso_local double @sum(double* nocapture %n, double %x) #0 { +entry: + %res = atomicrmw fadd double* %n, double %x monotonic + ret double %res +} + +; Function Attrs: nounwind uwtable +define dso_local void @dsum(double* %x, double* %xp, double %n) local_unnamed_addr #1 { +entry: + %0 = tail call double (double (double*, double)*, ...) @__enzyme_fwddiff(double (double*, double)* nonnull @sum, double* %x, double* %xp, double %n, double 1.000000e+00) + ret void +} + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double*, double)*, ...) #2 + +attributes #0 = { norecurse nounwind readonly uwtable } +attributes #1 = { nounwind uwtable } +attributes #2 = { nounwind } + +; CHECK: define internal double @fwddiffesum(double* nocapture %n, double* nocapture %"n'", double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %res = atomicrmw fadd double* %n, double %x monotonic +; CHECK-NEXT: %0 = atomicrmw fadd double* %"n'", double %"x'" monotonic +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: }