Skip to content

Commit

Permalink
Enable first stage of postprocessing optimization (rust-lang#497)
Browse files Browse the repository at this point in the history
* Actually run postopt

* Fix OpenMP gradient

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
wsmoses and Ubuntu authored Feb 7, 2022
1 parent 8964ab6 commit a5b2aa3
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 20 deletions.
35 changes: 28 additions & 7 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Passes/PassBuilder.h"

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
Expand Down Expand Up @@ -1629,7 +1630,6 @@ class Enzyme : public ModulePass {
// TODO consider enabling when attributor does not delete
// dead internal functions, which invalidates Enzyme's cache
// code left here to re-enable upon Attributor patch
Logic.PPC.FAM.clear(F, F.getName());

#if LLVM_VERSION_MAJOR >= 13 && !defined(FLANG)

Expand Down Expand Up @@ -1819,22 +1819,43 @@ class Enzyme : public ModulePass {
changed = true;
}

if (changed && Logic.PostOpt) {
auto &MAM = Logic.PPC.MAM;
auto &FAM = Logic.PPC.FAM;
PassBuilder PB;
PB.registerModuleAnalyses(MAM);
CGSCCAnalysisManager CGAM;
LoopAnalysisManager LAM;
PB.registerFunctionAnalyses(FAM);
PB.registerLoopAnalyses(LAM);
PB.registerCGSCCAnalyses(CGAM);
FAM.registerPass([&] { return CGSCCAnalysisManagerFunctionProxy(CGAM); });
FAM.registerPass([&] { return LoopAnalysisManagerFunctionProxy(LAM); });
LAM.registerPass([&] { return FunctionAnalysisManagerLoopProxy(FAM); });
MAM.registerPass([&] { return CGSCCAnalysisManagerModuleProxy(CGAM); });
CGAM.registerPass([&] { return ModuleAnalysisManagerCGSCCProxy(MAM); });
#if LLVM_VERSION_MAJOR >= 14
auto PM = PB.buildModuleSimplificationPipeline(OptimizationLevel::O2,
ThinOrFullLTOPhase::None);
#elif LLVM_VERSION_MAJOR >= 12
auto PM = PB.buildModuleSimplificationPipeline(
PassBuilder::OptimizationLevel::O2, ThinOrFullLTOPhase::None);
#else
auto PM = PB.buildModuleSimplificationPipeline(
PassBuilder::OptimizationLevel::O2, PassBuilder::ThinLTOPhase::None);
#endif
PM.run(M, MAM);
#if LLVM_VERSION_MAJOR >= 13
if (Logic.PostOpt) {
if (EnzymeOMPOpt) {
auto &MAM = Logic.PPC.MAM;
auto &FAM = Logic.PPC.FAM;
OpenMPOptPass().run(M, MAM);
/// Attributor is run second time for promoted args to get attributes.
AttributorPass().run(M, MAM);
for (auto &F : M)
if (!F.empty())
PromotePass().run(F, FAM);
changed = true;
}
}
#endif

}
Logic.clear();
return changed;
}
Expand Down
8 changes: 8 additions & 0 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2446,6 +2446,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
delete gutils;
tempFunc->eraseFromParent();

// Do not run post processing optimizations if the body of an openmp
// parallel so the adjointgenerator can successfully extract the allocation
// and frees and hoist them into the parent. Optimizing before then may
// make the IR different to traverse, and thus impossible to find the allocs.
if (PostOpt && !omp)
PPC.optimizeIntermediate(NewF);
if (EnzymePrint)
Expand Down Expand Up @@ -3815,6 +3819,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
PPC.ReplaceReallocs(nf, /*mem2reg*/ true);

// Do not run post processing optimizations if the body of an openmp
// parallel so the adjointgenerator can successfully extract the allocation
// and frees and hoist them into the parent. Optimizing before then may
// make the IR different to traverse, and thus impossible to find the allocs.
if (PostOpt && !omp)
PPC.optimizeIntermediate(nf);
if (EnzymePrint) {
Expand Down
15 changes: 5 additions & 10 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
Expand Down Expand Up @@ -2101,18 +2102,12 @@ void PreProcessCache::optimizeIntermediate(Function *F) {
}
}

PassManagerBuilder Builder;
Builder.OptLevel = 2;
legacy::FunctionPassManager PM(F->getParent());
Builder.populateFunctionPassManager(PM);
PM.run(*F);
{
PreservedAnalyses PA;
FAM.invalidate(*F, PA);
}
if (EnzymeCoalese)
CoaleseTrivialMallocs(*F, FAM.getResult<DominatorTreeAnalysis>(*F));
// DCEPass().run(*F, AM);

PreservedAnalyses PA;
FAM.invalidate(*F, PA);
// TODO actually run post optimizations.
}

void PreProcessCache::clear() {
Expand Down
10 changes: 10 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,16 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
auto replacewith =
(idx < 0) ? tape
: lb.CreateExtractValue(tape, {(unsigned)idx});
if (li->getType() != replacewith->getType()) {
llvm::errs() << " oldFunc: " << *oldFunc << "\n";
llvm::errs() << " newFunc: " << *newFunc << "\n";
llvm::errs() << " malloc: " << *malloc << "\n";
llvm::errs() << " li: " << *li << "\n";
llvm::errs() << " u: " << *u << "\n";
llvm::errs() << " replacewith: " << *replacewith
<< " idx=" << idx << " - tape=" << *tape << "\n";
}
assert(li->getType() == replacewith->getType());
li->replaceAllUsesWith(replacewith);
} else {
auto phi =
Expand Down
7 changes: 4 additions & 3 deletions enzyme/test/Enzyme/ReverseMode/ompsqloopPostOpt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ attributes #1 = { argmemonly }
!11 = !{!12}
!12 = !{i64 2, i64 -1, i64 -1, i1 true}


; CHECK-LABEL: define internal void @augmented_.omp_outlined..1(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* nocapture nonnull align 8 dereferenceable(8) %tmp, double* nocapture %"tmp'", double** %tape)
; CHECK-LABEL: define internal void @augmented_.omp_outlined..1(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* nocapture nonnull align 8 dereferenceable(8) %tmp, double* nocapture readnone %"tmp'", double** nocapture readonly %tape)
; CHECK-NOT: call{{.*}}@malloc
; CHECK: }

; CHECK-LABEL: define internal void @diffe.omp_outlined.(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* nocapture nonnull align 8 dereferenceable(8) %tmp, double* nocapture %"tmp'", double** %tapeArg)
; CHECK-LABEL: define internal void @diffe.omp_outlined.(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* nocapture nonnull readnone align 8 dereferenceable(8) %tmp, double* nocapture %"tmp'", double** nocapture readonly %tapeArg)
; CHECK-NOT: call{{.*}}@free
; CHECK: }

0 comments on commit a5b2aa3

Please sign in to comment.