From 1234291fea3010c16f7eecb0d7e684c59ab61b70 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 7 Nov 2024 17:37:58 -0600 Subject: [PATCH] LLVM integrate fixes (#2160) --- enzyme/Enzyme/ActivityAnalysis.cpp | 43 ++------------- enzyme/Enzyme/AdjointGenerator.h | 10 ++-- enzyme/Enzyme/DifferentialUseAnalysis.cpp | 11 +--- enzyme/Enzyme/Enzyme.cpp | 12 ++++- enzyme/Enzyme/EnzymeLogic.cpp | 30 +++-------- enzyme/Enzyme/GradientUtils.cpp | 53 +++++-------------- enzyme/Enzyme/MLIR/CMakeLists.txt | 1 + .../MLIR/Passes/AddToOpToIndexAndLoad.cpp | 6 +-- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 10 ++-- enzyme/Enzyme/Utils.cpp | 20 +++++++ enzyme/Enzyme/Utils.h | 3 ++ 11 files changed, 75 insertions(+), 124 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 8c8b718a0e83..cedaa1d19c3f 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -1442,20 +1442,9 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { if (auto LI = dyn_cast(TmpOrig)) return isConstantValue(TR, LI->getPointerOperand()); - if (isa(TmpOrig) && - (cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldu_global_i || - cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldu_global_p || - cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldu_global_f || - cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldg_global_i || - cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldg_global_p || - cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldg_global_f)) + if (isNVLoad(TmpOrig)) { return isConstantValue(TR, cast(TmpOrig)->getOperand(0)); + } if (TmpOrig == Val) return false; @@ -1547,19 +1536,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { TmpOrig); } } - } else if (isa(TmpOrig) && - (cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldu_global_i || - cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldu_global_p || - cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldu_global_f || - cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldg_global_i || - cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldg_global_p || - cast(TmpOrig)->getIntrinsicID() == - Intrinsic::nvvm_ldg_global_f)) { + } else if (isNVLoad(TmpOrig)) { auto II = cast(TmpOrig); if (directions == UP) { if (isConstantValue(TR, II->getOperand(0))) { @@ -1950,19 +1927,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { isRefSet(AARes)) { if (EnzymePrintActivity) llvm::errs() << "potential active load: " << *I << "\n"; - if (isa(I) || (isa(I) && - (cast(I)->getIntrinsicID() == - Intrinsic::nvvm_ldu_global_i || - cast(I)->getIntrinsicID() == - Intrinsic::nvvm_ldu_global_p || - cast(I)->getIntrinsicID() == - Intrinsic::nvvm_ldu_global_f || - cast(I)->getIntrinsicID() == - Intrinsic::nvvm_ldg_global_i || - cast(I)->getIntrinsicID() == - Intrinsic::nvvm_ldg_global_p || - cast(I)->getIntrinsicID() == - Intrinsic::nvvm_ldg_global_f))) { + if (isa(I) || isNVLoad(I)) { // If the ref'ing value is a load check if the loaded value is // active if (!Hypothesis->isConstantValue(TR, I)) { diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 952027337fad..2eae03312b73 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -3725,12 +3725,14 @@ class AdjointGenerator : public llvm::InstVisitor { Module *M = I.getParent()->getParent()->getParent(); switch (ID) { - case Intrinsic::nvvm_ldu_global_i: - case Intrinsic::nvvm_ldu_global_p: - case Intrinsic::nvvm_ldu_global_f: +#if LLVM_VERSION_MAJOR < 20 case Intrinsic::nvvm_ldg_global_i: case Intrinsic::nvvm_ldg_global_p: - case Intrinsic::nvvm_ldg_global_f: { + case Intrinsic::nvvm_ldg_global_f: +#endif + case Intrinsic::nvvm_ldu_global_i: + case Intrinsic::nvvm_ldu_global_p: + case Intrinsic::nvvm_ldu_global_f: { auto CI = cast(I.getOperand(1)); visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue()), /*constantval*/ false); diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index 5b4ef2bdc7c9..3d3722cfe634 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -949,17 +949,8 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, if (isAllocationCall(V, TLI) || isa(V)) { auto next = (*mp.begin()).V; bool noncapture = false; - if (isa(next)) { + if (isa(next) || isNVLoad(next)) { noncapture = true; - } else if (auto II = dyn_cast(next)) { - if (II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_i || - II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_p || - II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_f || - II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_i || - II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_p || - II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_f || - II->getIntrinsicID() == Intrinsic::masked_load) - noncapture = true; } else if (auto CI = dyn_cast(next)) { bool captures = false; for (size_t i = 0; i < CI->arg_size(); i++) { diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 230a5d25ff20..c33f66ff95b3 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -3360,7 +3360,13 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizePM))); }; - auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level) { +#if LLVM_VERSION_MAJOR >= 20 + auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level, + ThinOrFullLTOPhase) +#else + auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level) +#endif + { MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); if (!EnzymeEnable) @@ -3643,7 +3649,11 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { LPM.addPass(LoopDeletionPass()); // FIXME: Add loop interchange. +#if LLVM_VERSION_MAJOR >= 20 + loadPass(MPM, Level, ThinOrFullLTOPhase::None); +#else loadPass(MPM, Level); +#endif }; PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadLTO); } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index d6707e0592e9..4d4b2b5cbc05 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -422,21 +422,14 @@ struct CacheAnalysis { continue; for (auto &inst : B) { // For each load instruction, determine if it is uncacheable. - if (auto op = dyn_cast(&inst)) { - can_modref_map[op] = is_load_uncacheable(*op); - } - if (auto II = dyn_cast(&inst)) { + if (isa(&inst)) { + can_modref_map[&inst] = is_load_uncacheable(inst); + } else if (isNVLoad(&inst)) { + can_modref_map[&inst] = false; + } else if (auto II = dyn_cast(&inst)) { switch (II->getIntrinsicID()) { - case Intrinsic::nvvm_ldu_global_i: - case Intrinsic::nvvm_ldu_global_p: - case Intrinsic::nvvm_ldu_global_f: - case Intrinsic::nvvm_ldg_global_i: - case Intrinsic::nvvm_ldg_global_p: - case Intrinsic::nvvm_ldg_global_f: - can_modref_map[II] = false; - break; case Intrinsic::masked_load: - can_modref_map[II] = is_load_uncacheable(*II); + can_modref_map[&inst] = is_load_uncacheable(inst); break; default: break; @@ -5364,20 +5357,11 @@ class TruncateGenerator : public llvm::InstVisitor, llvm::SmallVectorImpl &orig_ops) { using namespace llvm; - switch (ID) { - case Intrinsic::nvvm_ldu_global_i: - case Intrinsic::nvvm_ldu_global_p: - case Intrinsic::nvvm_ldu_global_f: - case Intrinsic::nvvm_ldg_global_i: - case Intrinsic::nvvm_ldg_global_p: - case Intrinsic::nvvm_ldg_global_f: { + if (isNVLoad(&I)) { auto CI = cast(I.getOperand(1)); visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue())); return true; } - default: - break; - } if (ID == Intrinsic::masked_store) { auto align0 = cast(I.getOperand(2))->getZExtValue(); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index c116cafa0550..7c012506e374 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3904,15 +3904,9 @@ bool GradientUtils::legalRecompute(const Value *val, if (auto li = dyn_cast(val)) { const IntrinsicInst *II; - if (isa(li) || + if (isa(li) || isNVLoad(li) || ((II = dyn_cast(li)) && - (II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_i || - II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_p || - II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_f || - II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_i || - II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_p || - II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_f || - II->getIntrinsicID() == Intrinsic::masked_load))) { + (II->getIntrinsicID() == Intrinsic::masked_load))) { // If this is an already unwrapped value, legal to recompute again. if (unwrappedLoads.find(li) != unwrappedLoads.end()) return legalRecompute(unwrappedLoads.find(li)->second, available, @@ -4174,7 +4168,7 @@ bool GradientUtils::shouldRecompute(const Value *val, } if (auto op = dyn_cast(val)) { - if (!op->mayReadOrWriteMemory() || isReadNone(op)) + if (!op->mayReadOrWriteMemory() || isReadNone(op) || isNVLoad(op)) return true; switch (op->getIntrinsicID()) { case Intrinsic::sin: @@ -4186,12 +4180,6 @@ bool GradientUtils::shouldRecompute(const Value *val, case Intrinsic::sinh: #endif case Intrinsic::log: - case Intrinsic::nvvm_ldu_global_i: - case Intrinsic::nvvm_ldu_global_p: - case Intrinsic::nvvm_ldu_global_f: - case Intrinsic::nvvm_ldg_global_i: - case Intrinsic::nvvm_ldg_global_p: - case Intrinsic::nvvm_ldg_global_f: return true; default: return false; @@ -6109,12 +6097,14 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, switch (II->getIntrinsicID()) { default: goto end; - case Intrinsic::nvvm_ldu_global_i: - case Intrinsic::nvvm_ldu_global_p: - case Intrinsic::nvvm_ldu_global_f: +#if LLVM_VERSION_MAJOR < 20 case Intrinsic::nvvm_ldg_global_i: case Intrinsic::nvvm_ldg_global_p: - case Intrinsic::nvvm_ldg_global_f: { + case Intrinsic::nvvm_ldg_global_f: +#endif + case Intrinsic::nvvm_ldu_global_i: + case Intrinsic::nvvm_ldu_global_p: + case Intrinsic::nvvm_ldu_global_f: { return applyChainRule( II->getType(), bb, [&](Value *ptr) { @@ -6388,19 +6378,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool reduceRegister = false; if (EnzymeRegisterReduce) { - if (auto II = dyn_cast(inst)) { - switch (II->getIntrinsicID()) { - case Intrinsic::nvvm_ldu_global_i: - case Intrinsic::nvvm_ldu_global_p: - case Intrinsic::nvvm_ldu_global_f: - case Intrinsic::nvvm_ldg_global_i: - case Intrinsic::nvvm_ldg_global_p: - case Intrinsic::nvvm_ldg_global_f: - reduceRegister = true; - break; - default: - break; - } + if (isNVLoad(inst)) { + reduceRegister = true; } if (auto LI = dyn_cast(inst)) { auto Arch = @@ -9526,17 +9505,11 @@ bool GradientUtils::needsCacheWholeAllocation( continue; seen.insert(pair); // Loads are always fine - if (isa(cur)) + if (isa(cur) || isNVLoad(cur)) continue; if (auto II = dyn_cast(cur)) - if (II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_i || - II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_p || - II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_f || - II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_i || - II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_p || - II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_f || - II->getIntrinsicID() == Intrinsic::masked_load) + if (II->getIntrinsicID() == Intrinsic::masked_load) continue; bool returnedSameValue = false; diff --git a/enzyme/Enzyme/MLIR/CMakeLists.txt b/enzyme/Enzyme/MLIR/CMakeLists.txt index 14bd1be4424f..582648d53193 100644 --- a/enzyme/Enzyme/MLIR/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/CMakeLists.txt @@ -16,6 +16,7 @@ set(LIBS MLIREnzymeTransforms MLIREnzyme MLIROptLib + MLIRFuncInlinerExtension ) add_llvm_executable(enzymemlir-opt enzymemlir-opt.cpp) diff --git a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp index 648509f58e52..be1688187e37 100644 --- a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp +++ b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp @@ -76,12 +76,12 @@ struct AddToOpToIndexAndLoadPass cacheBuilder.setInsertionPoint(terminator); // Is it a fine assumption that all indexing maps are the same? - for (int i = 0; i < map[0].getNumDims(); i++) { + for (size_t i = 0; i < map[0].getNumDims(); i++) { indices.push_back(cacheBuilder.create(loc, i)); } SmallVector rets; - for (int i = 0; i < retargs.size(); i++) { + for (size_t i = 0; i < retargs.size(); i++) { // auto load = cacheBuilder.create(loc, inputs[i], map[i], // indices); auto store = cacheBuilder.create(loc, load, // inputs[i], map[i], indices); @@ -95,7 +95,7 @@ struct AddToOpToIndexAndLoadPass mapAppliedIndices); } - for (int i = 0; i < retargs.size(); i++) { + for (size_t i = 0; i < retargs.size(); i++) { SmallVector mapAppliedIndices = applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc); auto load = cacheBuilder.create(loc, outs[i], diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index fb41ea409e82..c77d86efc554 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -3898,12 +3898,14 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) { return; } - case Intrinsic::nvvm_ldu_global_i: - case Intrinsic::nvvm_ldu_global_p: - case Intrinsic::nvvm_ldu_global_f: +#if LLVM_VERSION_MAJOR < 20 case Intrinsic::nvvm_ldg_global_i: case Intrinsic::nvvm_ldg_global_p: - case Intrinsic::nvvm_ldg_global_f: { + case Intrinsic::nvvm_ldg_global_f: +#endif + case Intrinsic::nvvm_ldu_global_i: + case Intrinsic::nvvm_ldu_global_p: + case Intrinsic::nvvm_ldu_global_f: { auto &DL = I.getParent()->getParent()->getParent()->getDataLayout(); auto LoadSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8; diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 6723ff9e2453..2e78f7c083cd 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -3790,3 +3790,23 @@ void dumpBlock(llvm::BasicBlock *blk) { llvm::errs() << *blk << "\n"; } void dumpType(llvm::Type *ty) { llvm::errs() << *ty << "\n"; } void dumpTypeResults(TypeResults &TR) { TR.dump(); } + +bool isNVLoad(const llvm::Value *V) { + auto II = dyn_cast(V); + if (!II) + return false; + switch (II->getIntrinsicID()) { + case Intrinsic::nvvm_ldu_global_i: + case Intrinsic::nvvm_ldu_global_p: + case Intrinsic::nvvm_ldu_global_f: +#if LLVM_VERSION_MAJOR < 20 + case Intrinsic::nvvm_ldg_global_i: + case Intrinsic::nvvm_ldg_global_p: + case Intrinsic::nvvm_ldg_global_f: +#endif + return true; + default: + return false; + } + return false; +} diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 7c20b9f7c6b2..f825af47607c 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -2095,4 +2095,7 @@ llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &B, llvm::ArrayRef Args, llvm::Instruction *FMFSource = nullptr, const llvm::Twine &Name = ""); + +bool isNVLoad(const llvm::Value *V); + #endif // ENZYME_UTILS_H