Skip to content

Commit

Permalink
Add forward mode fadd and free debug (rust-lang#674)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored May 29, 2022
1 parent bb70046 commit f0f64db
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 2 deletions.
53 changes: 52 additions & 1 deletion enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,56 @@ class AdjointGenerator
}

void visitAtomicRMWInst(llvm::AtomicRMWInst &I) {
if (Mode == DerivativeMode::ForwardMode) {
IRBuilder<> BuilderZ(&I);
getForwardBuilder(BuilderZ);
switch (I.getOperation()) {
case AtomicRMWInst::FAdd:
case AtomicRMWInst::FSub: {
auto rule = [&](Value *ptr, Value *dif) -> Value * {
if (!gutils->isConstantInstruction(&I)) {
assert(ptr);
AtomicRMWInst *rmw = nullptr;
#if LLVM_VERSION_MAJOR >= 13
rmw = BuilderZ.CreateAtomicRMW(I.getOperation(), ptr, dif,
I.getAlign(), I.getOrdering(),
I.getSyncScopeID());
#elif LLVM_VERSION_MAJOR >= 11
rmw = BuilderZ.CreateAtomicRMW(I.getOperation(), ptr, dif,
I.getOrdering(), I.getSyncScopeID());
rmw->setAlignment(I.getAlign());
#else
rmw = BuilderZ.CreateAtomicRMW(
I.getOperation(), ptr, dif, I.getOrdering(),
I.getSyncScopeID());
#endif
rmw->setVolatile(I.isVolatile());
if (gutils->isConstantValue(&I))
return Constant::getNullValue(dif->getType());
else
return rmw;
} else {
assert(gutils->isConstantValue(&I));
return Constant::getNullValue(dif->getType());
}
};

Value *diff = applyChainRule(
I.getType(), BuilderZ, rule,
gutils->isConstantValue(I.getPointerOperand())
? nullptr
: gutils->invertPointerM(I.getPointerOperand(), BuilderZ),
gutils->isConstantValue(I.getValOperand())
? Constant::getNullValue(I.getType())
: gutils->invertPointerM(I.getValOperand(), BuilderZ));
if (!gutils->isConstantValue(&I))
setDiffe(&I, diff, BuilderZ);
return;
}
default:
break;
}
}
if (!gutils->isConstantInstruction(&I) || !gutils->isConstantValue(&I)) {
TR.dump();
llvm::errs() << "oldFunc: " << *gutils->newFunc << "\n";
Expand Down Expand Up @@ -11083,7 +11133,8 @@ class AdjointGenerator
auto rule = [&args](Value *tofree) { args.push_back(tofree); };
applyChainRule(Builder2, rule, tofree);

Builder2.CreateCall(free->getFunctionType(), free, args);
auto frees = Builder2.CreateCall(free->getFunctionType(), free, args);
frees->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc()));

return;
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3821,7 +3821,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
}
}

if (auto arg = dyn_cast<ConstantExpr>(oval)) {
if (isa<ConstantExpr>(oval)) {
auto rule = [&oval]() { return oval; };
return applyChainRule(oval->getType(), BuilderM, rule);
}
Expand Down
29 changes: 29 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/atomicfadd.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
; RUN: if [ %llvmver -ge 9 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi

; Function Attrs: norecurse nounwind readonly uwtable
define dso_local double @sum(double* nocapture %n, double %x) #0 {
entry:
%res = atomicrmw fadd double* %n, double %x monotonic
ret double %res
}

; Function Attrs: nounwind uwtable
define dso_local void @dsum(double* %x, double* %xp, double %n) local_unnamed_addr #1 {
entry:
%0 = tail call double (double (double*, double)*, ...) @__enzyme_fwddiff(double (double*, double)* nonnull @sum, double* %x, double* %xp, double %n, double 1.000000e+00)
ret void
}

; Function Attrs: nounwind
declare double @__enzyme_fwddiff(double (double*, double)*, ...) #2

attributes #0 = { norecurse nounwind readonly uwtable }
attributes #1 = { nounwind uwtable }
attributes #2 = { nounwind }

; CHECK: define internal double @fwddiffesum(double* nocapture %n, double* nocapture %"n'", double %x, double %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %res = atomicrmw fadd double* %n, double %x monotonic
; CHECK-NEXT: %0 = atomicrmw fadd double* %"n'", double %"x'" monotonic
; CHECK-NEXT: ret double %0
; CHECK-NEXT: }

0 comments on commit f0f64db

Please sign in to comment.