Skip to content

Commit

Permalink
LLVM integrate fixes (#2160)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Nov 7, 2024
1 parent a78b79c commit 1234291
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 124 deletions.
43 changes: 4 additions & 39 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1442,20 +1442,9 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {

if (auto LI = dyn_cast<LoadInst>(TmpOrig))
return isConstantValue(TR, LI->getPointerOperand());
if (isa<IntrinsicInst>(TmpOrig) &&
(cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_i ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_p ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_f ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_i ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_p ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_f))
if (isNVLoad(TmpOrig)) {
return isConstantValue(TR, cast<Instruction>(TmpOrig)->getOperand(0));
}

if (TmpOrig == Val)
return false;
Expand Down Expand Up @@ -1547,19 +1536,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
TmpOrig);
}
}
} else if (isa<IntrinsicInst>(TmpOrig) &&
(cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_i ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_p ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_f ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_i ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_p ||
cast<IntrinsicInst>(TmpOrig)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_f)) {
} else if (isNVLoad(TmpOrig)) {
auto II = cast<IntrinsicInst>(TmpOrig);
if (directions == UP) {
if (isConstantValue(TR, II->getOperand(0))) {
Expand Down Expand Up @@ -1950,19 +1927,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
isRefSet(AARes)) {
if (EnzymePrintActivity)
llvm::errs() << "potential active load: " << *I << "\n";
if (isa<LoadInst>(I) || (isa<IntrinsicInst>(I) &&
(cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_i ||
cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_p ||
cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldu_global_f ||
cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_i ||
cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_p ||
cast<IntrinsicInst>(I)->getIntrinsicID() ==
Intrinsic::nvvm_ldg_global_f))) {
if (isa<LoadInst>(I) || isNVLoad(I)) {
// If the ref'ing value is a load check if the loaded value is
// active
if (!Hypothesis->isConstantValue(TR, I)) {
Expand Down
10 changes: 6 additions & 4 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3725,12 +3725,14 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
Module *M = I.getParent()->getParent()->getParent();

switch (ID) {
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
#if LLVM_VERSION_MAJOR < 20
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f: {
case Intrinsic::nvvm_ldg_global_f:
#endif
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f: {
auto CI = cast<ConstantInt>(I.getOperand(1));
visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue()),
/*constantval*/ false);
Expand Down
11 changes: 1 addition & 10 deletions enzyme/Enzyme/DifferentialUseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -949,17 +949,8 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
if (isAllocationCall(V, TLI) || isa<AllocaInst>(V)) {
auto next = (*mp.begin()).V;
bool noncapture = false;
if (isa<LoadInst>(next)) {
if (isa<LoadInst>(next) || isNVLoad(next)) {
noncapture = true;
} else if (auto II = dyn_cast<IntrinsicInst>(next)) {
if (II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_f ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_f ||
II->getIntrinsicID() == Intrinsic::masked_load)
noncapture = true;
} else if (auto CI = dyn_cast<CallInst>(next)) {
bool captures = false;
for (size_t i = 0; i < CI->arg_size(); i++) {
Expand Down
12 changes: 11 additions & 1 deletion enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3360,7 +3360,13 @@ void augmentPassBuilder(llvm::PassBuilder &PB) {
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizePM)));
};

auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level) {
#if LLVM_VERSION_MAJOR >= 20
auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level,
ThinOrFullLTOPhase)
#else
auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level)
#endif
{
MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true));

if (!EnzymeEnable)
Expand Down Expand Up @@ -3643,7 +3649,11 @@ void augmentPassBuilder(llvm::PassBuilder &PB) {
LPM.addPass(LoopDeletionPass());
// FIXME: Add loop interchange.

#if LLVM_VERSION_MAJOR >= 20
loadPass(MPM, Level, ThinOrFullLTOPhase::None);
#else
loadPass(MPM, Level);
#endif
};
PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadLTO);
}
Expand Down
30 changes: 7 additions & 23 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,21 +422,14 @@ struct CacheAnalysis {
continue;
for (auto &inst : B) {
// For each load instruction, determine if it is uncacheable.
if (auto op = dyn_cast<LoadInst>(&inst)) {
can_modref_map[op] = is_load_uncacheable(*op);
}
if (auto II = dyn_cast<IntrinsicInst>(&inst)) {
if (isa<LoadInst>(&inst)) {
can_modref_map[&inst] = is_load_uncacheable(inst);
} else if (isNVLoad(&inst)) {
can_modref_map[&inst] = false;
} else if (auto II = dyn_cast<IntrinsicInst>(&inst)) {
switch (II->getIntrinsicID()) {
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f:
can_modref_map[II] = false;
break;
case Intrinsic::masked_load:
can_modref_map[II] = is_load_uncacheable(*II);
can_modref_map[&inst] = is_load_uncacheable(inst);
break;
default:
break;
Expand Down Expand Up @@ -5364,20 +5357,11 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
llvm::SmallVectorImpl<llvm::Value *> &orig_ops) {
using namespace llvm;

switch (ID) {
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f: {
if (isNVLoad(&I)) {
auto CI = cast<ConstantInt>(I.getOperand(1));
visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue()));
return true;
}
default:
break;
}

if (ID == Intrinsic::masked_store) {
auto align0 = cast<ConstantInt>(I.getOperand(2))->getZExtValue();
Expand Down
53 changes: 13 additions & 40 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3904,15 +3904,9 @@ bool GradientUtils::legalRecompute(const Value *val,
if (auto li = dyn_cast<Instruction>(val)) {

const IntrinsicInst *II;
if (isa<LoadInst>(li) ||
if (isa<LoadInst>(li) || isNVLoad(li) ||
((II = dyn_cast<IntrinsicInst>(li)) &&
(II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_f ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_f ||
II->getIntrinsicID() == Intrinsic::masked_load))) {
(II->getIntrinsicID() == Intrinsic::masked_load))) {
// If this is an already unwrapped value, legal to recompute again.
if (unwrappedLoads.find(li) != unwrappedLoads.end())
return legalRecompute(unwrappedLoads.find(li)->second, available,
Expand Down Expand Up @@ -4174,7 +4168,7 @@ bool GradientUtils::shouldRecompute(const Value *val,
}

if (auto op = dyn_cast<IntrinsicInst>(val)) {
if (!op->mayReadOrWriteMemory() || isReadNone(op))
if (!op->mayReadOrWriteMemory() || isReadNone(op) || isNVLoad(op))
return true;
switch (op->getIntrinsicID()) {
case Intrinsic::sin:
Expand All @@ -4186,12 +4180,6 @@ bool GradientUtils::shouldRecompute(const Value *val,
case Intrinsic::sinh:
#endif
case Intrinsic::log:
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f:
return true;
default:
return false;
Expand Down Expand Up @@ -6109,12 +6097,14 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
switch (II->getIntrinsicID()) {
default:
goto end;
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
#if LLVM_VERSION_MAJOR < 20
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f: {
case Intrinsic::nvvm_ldg_global_f:
#endif
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f: {
return applyChainRule(
II->getType(), bb,
[&](Value *ptr) {
Expand Down Expand Up @@ -6388,19 +6378,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
bool reduceRegister = false;

if (EnzymeRegisterReduce) {
if (auto II = dyn_cast<IntrinsicInst>(inst)) {
switch (II->getIntrinsicID()) {
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f:
reduceRegister = true;
break;
default:
break;
}
if (isNVLoad(inst)) {
reduceRegister = true;
}
if (auto LI = dyn_cast<LoadInst>(inst)) {
auto Arch =
Expand Down Expand Up @@ -9526,17 +9505,11 @@ bool GradientUtils::needsCacheWholeAllocation(
continue;
seen.insert(pair);
// Loads are always fine
if (isa<LoadInst>(cur))
if (isa<LoadInst>(cur) || isNVLoad(cur))
continue;

if (auto II = dyn_cast<IntrinsicInst>(cur))
if (II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_f ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_i ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_p ||
II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_f ||
II->getIntrinsicID() == Intrinsic::masked_load)
if (II->getIntrinsicID() == Intrinsic::masked_load)
continue;

bool returnedSameValue = false;
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ set(LIBS
MLIREnzymeTransforms
MLIREnzyme
MLIROptLib
MLIRFuncInlinerExtension
)
add_llvm_executable(enzymemlir-opt enzymemlir-opt.cpp)

Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ struct AddToOpToIndexAndLoadPass
cacheBuilder.setInsertionPoint(terminator);

// Is it a fine assumption that all indexing maps are the same?
for (int i = 0; i < map[0].getNumDims(); i++) {
for (size_t i = 0; i < map[0].getNumDims(); i++) {
indices.push_back(cacheBuilder.create<linalg::IndexOp>(loc, i));
}

SmallVector<Value> rets;
for (int i = 0; i < retargs.size(); i++) {
for (size_t i = 0; i < retargs.size(); i++) {
// auto load = cacheBuilder.create<AffineLoadOp>(loc, inputs[i], map[i],
// indices); auto store = cacheBuilder.create<AffineStoreOp>(loc, load,
// inputs[i], map[i], indices);
Expand All @@ -95,7 +95,7 @@ struct AddToOpToIndexAndLoadPass
mapAppliedIndices);
}

for (int i = 0; i < retargs.size(); i++) {
for (size_t i = 0; i < retargs.size(); i++) {
SmallVector<Value> mapAppliedIndices =
applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc);
auto load = cacheBuilder.create<memref::LoadOp>(loc, outs[i],
Expand Down
10 changes: 6 additions & 4 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3898,12 +3898,14 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
return;
}

case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
#if LLVM_VERSION_MAJOR < 20
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f: {
case Intrinsic::nvvm_ldg_global_f:
#endif
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f: {
auto &DL = I.getParent()->getParent()->getParent()->getDataLayout();
auto LoadSize = (DL.getTypeSizeInBits(I.getType()) + 7) / 8;

Expand Down
20 changes: 20 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3790,3 +3790,23 @@ void dumpBlock(llvm::BasicBlock *blk) { llvm::errs() << *blk << "\n"; }
void dumpType(llvm::Type *ty) { llvm::errs() << *ty << "\n"; }

void dumpTypeResults(TypeResults &TR) { TR.dump(); }

bool isNVLoad(const llvm::Value *V) {
auto II = dyn_cast<IntrinsicInst>(V);
if (!II)
return false;
switch (II->getIntrinsicID()) {
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
case Intrinsic::nvvm_ldu_global_f:
#if LLVM_VERSION_MAJOR < 20
case Intrinsic::nvvm_ldg_global_i:
case Intrinsic::nvvm_ldg_global_p:
case Intrinsic::nvvm_ldg_global_f:
#endif
return true;
default:
return false;
}
return false;
}
3 changes: 3 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2095,4 +2095,7 @@ llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &B,
llvm::ArrayRef<llvm::Value *> Args,
llvm::Instruction *FMFSource = nullptr,
const llvm::Twine &Name = "");

bool isNVLoad(const llvm::Value *V);

#endif // ENZYME_UTILS_H

0 comments on commit 1234291

Please sign in to comment.