diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index b3f76041a5894..1d0ab292c84a2 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -31,6 +31,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Passes/PassBuilder.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -1629,7 +1630,6 @@ class Enzyme : public ModulePass { // TODO consider enabling when attributor does not delete // dead internal functions, which invalidates Enzyme's cache // code left here to re-enable upon Attributor patch - Logic.PPC.FAM.clear(F, F.getName()); #if LLVM_VERSION_MAJOR >= 13 && !defined(FLANG) @@ -1819,22 +1819,43 @@ class Enzyme : public ModulePass { changed = true; } + if (changed && Logic.PostOpt) { + auto &MAM = Logic.PPC.MAM; + auto &FAM = Logic.PPC.FAM; + PassBuilder PB; + PB.registerModuleAnalyses(MAM); + CGSCCAnalysisManager CGAM; + LoopAnalysisManager LAM; + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.registerCGSCCAnalyses(CGAM); + FAM.registerPass([&] { return CGSCCAnalysisManagerFunctionProxy(CGAM); }); + FAM.registerPass([&] { return LoopAnalysisManagerFunctionProxy(LAM); }); + LAM.registerPass([&] { return FunctionAnalysisManagerLoopProxy(FAM); }); + MAM.registerPass([&] { return CGSCCAnalysisManagerModuleProxy(CGAM); }); + CGAM.registerPass([&] { return ModuleAnalysisManagerCGSCCProxy(MAM); }); +#if LLVM_VERSION_MAJOR >= 14 + auto PM = PB.buildModuleSimplificationPipeline(OptimizationLevel::O2, + ThinOrFullLTOPhase::None); +#elif LLVM_VERSION_MAJOR >= 12 + auto PM = PB.buildModuleSimplificationPipeline( + PassBuilder::OptimizationLevel::O2, ThinOrFullLTOPhase::None); +#else + auto PM = PB.buildModuleSimplificationPipeline( + PassBuilder::OptimizationLevel::O2, PassBuilder::ThinLTOPhase::None); +#endif + PM.run(M, MAM); #if LLVM_VERSION_MAJOR >= 13 - if (Logic.PostOpt) { if (EnzymeOMPOpt) { - auto &MAM = Logic.PPC.MAM; - auto 
&FAM = Logic.PPC.FAM; OpenMPOptPass().run(M, MAM); /// Attributor is run second time for promoted args to get attributes. AttributorPass().run(M, MAM); for (auto &F : M) if (!F.empty()) PromotePass().run(F, FAM); - changed = true; } - } #endif - + } Logic.clear(); return changed; } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 298bc53e2d065..87ea0ccf9b0d8 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -2446,6 +2446,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( delete gutils; tempFunc->eraseFromParent(); + // Do not run post-processing optimizations on the body of an OpenMP parallel + // region, so that the adjointgenerator can successfully extract the + // allocations and frees and hoist them into the parent. Optimizing before + // then may make the IR harder to traverse, and the allocs impossible to find. if (PostOpt && !omp) PPC.optimizeIntermediate(NewF); if (EnzymePrint) @@ -3815,6 +3819,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient( if (Arch == Triple::nvptx || Arch == Triple::nvptx64) PPC.ReplaceReallocs(nf, /*mem2reg*/ true); + // Do not run post-processing optimizations on the body of an OpenMP parallel + // region, so that the adjointgenerator can successfully extract the + // allocations and frees and hoist them into the parent. Optimizing before + // then may make the IR harder to traverse, and the allocs impossible to find. 
if (PostOpt && !omp) PPC.optimizeIntermediate(nf); if (EnzymePrint) { diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index e43ba5490ca20..f7cb9f078f70a 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -36,6 +36,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Verifier.h" +#include "llvm/Passes/PassBuilder.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" @@ -2101,18 +2102,12 @@ void PreProcessCache::optimizeIntermediate(Function *F) { } } - PassManagerBuilder Builder; - Builder.OptLevel = 2; - legacy::FunctionPassManager PM(F->getParent()); - Builder.populateFunctionPassManager(PM); - PM.run(*F); - { - PreservedAnalyses PA; - FAM.invalidate(*F, PA); - } if (EnzymeCoalese) CoaleseTrivialMallocs(*F, FAM.getResult(*F)); - // DCEPass().run(*F, AM); + + PreservedAnalyses PA; + FAM.invalidate(*F, PA); + // TODO actually run post optimizations. } void PreProcessCache::clear() { diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index c5578a7a78cb5..7e129baf9fae3 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -1718,6 +1718,16 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, auto replacewith = (idx < 0) ? 
tape : lb.CreateExtractValue(tape, {(unsigned)idx}); + if (li->getType() != replacewith->getType()) { + llvm::errs() << " oldFunc: " << *oldFunc << "\n"; + llvm::errs() << " newFunc: " << *newFunc << "\n"; + llvm::errs() << " malloc: " << *malloc << "\n"; + llvm::errs() << " li: " << *li << "\n"; + llvm::errs() << " u: " << *u << "\n"; + llvm::errs() << " replacewith: " << *replacewith + << " idx=" << idx << " - tape=" << *tape << "\n"; + } + assert(li->getType() == replacewith->getType()); li->replaceAllUsesWith(replacewith); } else { auto phi = diff --git a/enzyme/test/Enzyme/ReverseMode/ompsqloopPostOpt.ll b/enzyme/test/Enzyme/ReverseMode/ompsqloopPostOpt.ll index 90750acff1c54..cb80da49022a7 100644 --- a/enzyme/test/Enzyme/ReverseMode/ompsqloopPostOpt.ll +++ b/enzyme/test/Enzyme/ReverseMode/ompsqloopPostOpt.ll @@ -114,9 +114,10 @@ attributes #1 = { argmemonly } !11 = !{!12} !12 = !{i64 2, i64 -1, i64 -1, i1 true} - -; CHECK-LABEL: define internal void @augmented_.omp_outlined..1(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* nocapture nonnull align 8 dereferenceable(8) %tmp, double* nocapture %"tmp'", double** %tape) +; CHECK-LABEL: define internal void @augmented_.omp_outlined..1(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* nocapture nonnull align 8 dereferenceable(8) %tmp, double* nocapture readnone %"tmp'", double** nocapture readonly %tape) ; CHECK-NOT: call{{.*}}@malloc +; CHECK: } -; CHECK-LABEL: define internal void @diffe.omp_outlined.(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* nocapture nonnull align 8 dereferenceable(8) %tmp, double* nocapture %"tmp'", double** %tapeArg) +; CHECK-LABEL: define internal void @diffe.omp_outlined.(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* 
nocapture nonnull readnone align 8 dereferenceable(8) %tmp, double* nocapture %"tmp'", double** nocapture readonly %tapeArg) ; CHECK-NOT: call{{.*}}@free +; CHECK: }