Skip to content

Commit

Permalink
feat(loopextract): make loop into nd function
Browse files Browse the repository at this point in the history
This modified LoopExtract pass does the following:
1. Extract all conformant loops into functions
2. Replace the body of each extracted function so that all of its outputs are set to ND (non-deterministic) values

There are various ways the code can be improved:
1. Fix the todos
2. Instead of destructively modifying the function, create another function and point all calls to it

For now, my inclination is to submit this as-is, since its utility is unknown and I want to avoid over-engineering it.

Tested: c-rust tinyvec-capacity-error job (now takes less than a second, compared with 3 minutes previously)
  • Loading branch information
priyasiddharth authored and agurfinkel committed Aug 11, 2023
1 parent 7cabdad commit ba69e31
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 32 deletions.
2 changes: 2 additions & 0 deletions include/llvm_seahorn/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ void initializeSeaInstructionCombiningPassPass(PassRegistry &);
void initializeSeaLoopRotateLegacyPassPass(PassRegistry &);
void initializeSeaLoopUnrollPass(PassRegistry &);
void initializeSeaAnnotation2MetadataLegacyPass(PassRegistry &);
void initializeSeaLoopExtractorLegacyPassPass(PassRegistry &);
void initializeSeaSingleLoopExtractorPass(PassRegistry &);
} // end namespace llvm

#endif // SEA_LLVM_INITIALIZEPASSES_H
2 changes: 2 additions & 0 deletions include/llvm_seahorn/Transforms/IPO.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ class ModulePass;

namespace llvm_seahorn {
llvm::ModulePass *createSeaAnnotation2MetadataLegacyPass();
llvm::ModulePass *createSeaLoopExtractorPass();
llvm::ModulePass *createSeaSingleLoopExtractorPass();
} // namespace llvm_seahorn

#endif
29 changes: 29 additions & 0 deletions include/llvm_seahorn/Transforms/IPO/SeaLoopExtractor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//
//===----------------------------------------------------------------------===//
//
// A pass wrapper around the ExtractLoop() scalar transformation to extract each
// top-level loop into its own new function. If the loop is the ONLY loop in a
// given function, it is not touched. This is a pass most useful for debugging
// via bugpoint.
//
//===----------------------------------------------------------------------===//

// NOTE: the guard macro is deliberately SeaHorn-specific. Reusing upstream
// LLVM's LLVM_TRANSFORMS_IPO_LOOPEXTRACTOR_H would make this header and
// llvm/Transforms/IPO/LoopExtractor.h mutually exclusive within one
// translation unit.
#ifndef SEA_LLVM_TRANSFORMS_IPO_SEALOOPEXTRACTOR_H
#define SEA_LLVM_TRANSFORMS_IPO_SEALOOPEXTRACTOR_H

#include "llvm/IR/PassManager.h"

namespace llvm {

/// Module pass (PassInfoMixin style) that runs the SeaHorn loop extractor
/// over a module.
struct SeaLoopExtractorPass : public PassInfoMixin<SeaLoopExtractorPass> {
  /// \param NumLoops budget of loops that may be extracted; the default (~0,
  ///        all bits set) effectively means "no limit".
  SeaLoopExtractorPass(unsigned NumLoops = ~0) : NumLoops(NumLoops) {}

  /// Runs the extractor on \p M, using \p AM to obtain per-function analyses.
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);

  /// Prints this pass's textual pipeline representation to \p OS.
  void printPipeline(raw_ostream &OS,
                     function_ref<StringRef(StringRef)> MapClassName2PassName);

private:
  /// Extraction budget handed to the extractor implementation.
  unsigned NumLoops;
};
} // namespace llvm

#endif // SEA_LLVM_TRANSFORMS_IPO_SEALOOPEXTRACTOR_H
1 change: 1 addition & 0 deletions lib/Transforms/IPO/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_llvm_library(SeaLlvmIpo DISABLE_LLVM_LINK_LLVM_DYLIB
PassManagerBuilder.cpp
Annotation2Metadata.cpp
LoopExtractor.cpp
)
140 changes: 108 additions & 32 deletions lib/Transforms/IPO/LoopExtractor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,21 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/LoopExtractor.h"
#include "llvm_seahorn/InitializePasses.h"
#include "llvm_seahorn/Transforms/IPO.h"
#include "llvm_seahorn/Transforms/IPO/SeaLoopExtractor.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
Expand All @@ -31,21 +37,85 @@
#include "llvm/Transforms/Utils/CodeExtractor.h"
#include <fstream>
#include <set>

using namespace llvm;

#define DEBUG_TYPE "loop-extract"
#define DEBUG_TYPE "sea-loop-extract"

STATISTIC(NumExtracted, "Number of loops extracted");

DenseMap<const Type *, Function *> m_ndfn;

// TODO: extract into common library for nondet pass and this code
//
// Declares a new function in module `m` that returns `type`, and hands back a
// reference to it. The function is named `prefix` followed by a decimal
// counter; the counter starts at `num` and is bumped until the resulting name
// is not already taken by any named value in the module.
Function &createNewNondetFn(Module &m, Type &type, unsigned num,
                            std::string prefix) {
  unsigned suffix = num;
  std::string candidate = prefix + std::to_string(suffix++);
  // Keep probing successive suffixes while the candidate name is in use.
  while (m.getNamedValue(candidate))
    candidate = prefix + std::to_string(suffix++);

  auto callee = m.getOrInsertFunction(candidate, &type);
  Function *fn = dyn_cast<Function>(callee.getCallee());
  assert(fn);
  return *fn;
}

// TODO: extract into common library for nondet pass and this code
//
// Returns the nondet function associated with `type`, creating and caching
// one named "verifier.nondet.<N>" in module `m` on first request.
//
// NOTE(review): the cache (m_ndfn) is keyed by type only and never consults
// `m` on a hit -- confirm it is never shared across different modules.
Function *getNondetFn(Type *type, Module *m) {
  auto cached = m_ndfn.find(type);
  if (cached == m_ndfn.end()) {
    // The current cache size seeds the numeric suffix of the new name.
    Function &fresh =
        createNewNondetFn(*m, *type, m_ndfn.size(), "verifier.nondet.");
    m_ndfn[type] = &fresh;
    return &fresh;
  }
  return cached->second;
}

// Replace the given function body with code that stores ND
// (non-deterministic) values in its output args and, when the function has a
// non-void return type, returns an ND value as well.
//
// oldfn   - function whose body is deleted and rebuilt in place; its
//           signature is left untouched.
// inputs  - input values reported by the code extractor; only their count is
//           used, to locate the first output argument.
// outputs - output values reported by the code extractor; the type of each
//           one selects the nondet function whose result is stored through
//           the matching output argument.
void replaceFnBodyWithND(Function *oldfn, SetVector<Value *> &inputs,
                         SetVector<Value *> &outputs) {
  Function *TheFunction = oldfn;
  Module *M = TheFunction->getParent();
  Type *ret_ty = TheFunction->getReturnType();

  // Record the output types BEFORE deleting the body: the output values may
  // be instructions inside that body and must not be touched afterwards.
  // Taking the type from the value (rather than from the pointer argument via
  // the deprecated getPointerElementType) also works with opaque pointers.
  SmallVector<Type *, 8> out_tys;
  out_tys.reserve(outputs.size());
  for (Value *out : outputs)
    out_tys.push_back(out->getType());

  TheFunction->dropAllReferences(); // delete body of function

  BasicBlock *BB =
      BasicBlock::Create(TheFunction->getContext(), "entry", TheFunction);
  IRBuilder<> Builder(BB); // inserts at the end of the fresh entry block

  // store nd values in output args
  // ASSUME: CodeRegionExtractor creates function formal arg list in the order:
  // fn(IN_0, IN_1, ..., IN_N, OUT_0, OUT_1, ..., OUT_M)
  const unsigned first_out = inputs.size();
  assert(first_out + out_tys.size() <= TheFunction->arg_size() &&
         "fewer formal arguments than extracted inputs + outputs");
  for (unsigned k = 0; k < out_tys.size(); ++k) {
    Argument *out_ptr = TheFunction->getArg(first_out + k);
    // ASSUME: every output argument is a pointer to the output's type
    assert(out_ptr->getType()->isPointerTy() &&
           "output argument is expected to be a pointer");
    Value *nd_val = Builder.CreateCall(getNondetFn(out_tys[k], M));
    Builder.CreateStore(nd_val, out_ptr);
  }

  // set return value to nd; a void function must end in a plain `ret void`,
  // since returning the result of a void call yields invalid IR
  if (ret_ty->isVoidTy())
    Builder.CreateRetVoid();
  else
    Builder.CreateRet(Builder.CreateCall(getNondetFn(ret_ty, M)));

  // verifyFunction returns true when the function is broken; do not silently
  // carry on with invalid IR in assertion-enabled builds.
  const bool broken = verifyFunction(*TheFunction);
  (void)broken;
  assert(!broken && "rebuilt nondet function failed IR verification");
}

namespace {
struct LoopExtractorLegacyPass : public ModulePass {
struct SeaLoopExtractorLegacyPass : public ModulePass {
static char ID; // Pass identification, replacement for typeid

unsigned NumLoops;

explicit LoopExtractorLegacyPass(unsigned NumLoops = ~0)
explicit SeaLoopExtractorLegacyPass(unsigned NumLoops = ~0)
: ModulePass(ID), NumLoops(NumLoops) {
initializeLoopExtractorLegacyPassPass(*PassRegistry::getPassRegistry());
initializeSeaLoopExtractorLegacyPassPass(*PassRegistry::getPassRegistry());
}

bool runOnModule(Module &M) override;
Expand All @@ -60,8 +130,8 @@ struct LoopExtractorLegacyPass : public ModulePass {
}
};

struct LoopExtractor {
explicit LoopExtractor(
struct SeaLoopExtractor {
explicit SeaLoopExtractor(
unsigned NumLoops,
function_ref<DominatorTree &(Function &)> LookupDomTree,
function_ref<LoopInfo &(Function &)> LookupLoopInfo,
Expand All @@ -87,34 +157,36 @@ struct LoopExtractor {
};
} // namespace

char LoopExtractorLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LoopExtractorLegacyPass, "loop-extract",
char SeaLoopExtractorLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(SeaLoopExtractorLegacyPass, "sea-loop-extract",
"Extract loops into new functions", false, false)
INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_END(LoopExtractorLegacyPass, "loop-extract",
INITIALIZE_PASS_END(SeaLoopExtractorLegacyPass, "sea-loop-extract",
"Extract loops into new functions", false, false)

namespace {
/// SingleLoopExtractor - For bugpoint.
struct SingleLoopExtractor : public LoopExtractorLegacyPass {
/// SingleLoopExtractor - For bugpoint.
struct SeaSingleLoopExtractor : public SeaLoopExtractorLegacyPass {
static char ID; // Pass identification, replacement for typeid
SingleLoopExtractor() : LoopExtractorLegacyPass(1) {}
SeaSingleLoopExtractor() : SeaLoopExtractorLegacyPass(1) {}
};
} // End anonymous namespace

char SingleLoopExtractor::ID = 0;
INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single",
char SeaSingleLoopExtractor::ID = 0;
INITIALIZE_PASS(SeaSingleLoopExtractor, "sea-loop-extract-single",
"Extract at most one loop into a new function", false, false)

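// createSeaLoopExtractorPass - This pass extracts all natural loops from the
// program into functions if it can, and replaces the body of each extracted
// function with code that produces non-deterministic (ND) values.
//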
Pass *llvm::createLoopExtractorPass() { return new LoopExtractorLegacyPass(); }
// Factory: allocates a SeaLoopExtractorLegacyPass using the constructor's
// default loop limit and returns it as a ModulePass pointer.
ModulePass *llvm_seahorn::createSeaLoopExtractorPass() {
return new SeaLoopExtractorLegacyPass();
}

bool LoopExtractorLegacyPass::runOnModule(Module &M) {
bool SeaLoopExtractorLegacyPass::runOnModule(Module &M) {
if (skipModule(M))
return false;

Expand All @@ -130,12 +202,12 @@ bool LoopExtractorLegacyPass::runOnModule(Module &M) {
return ACT->lookupAssumptionCache(F);
return nullptr;
};
return LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, LookupACT)
return SeaLoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo, LookupACT)
.runOnModule(M) ||
Changed;
}

bool LoopExtractor::runOnModule(Module &M) {
bool SeaLoopExtractor::runOnModule(Module &M) {
if (M.empty())
return false;

Expand Down Expand Up @@ -163,7 +235,7 @@ bool LoopExtractor::runOnModule(Module &M) {
return Changed;
}

bool LoopExtractor::runOnFunction(Function &F) {
bool SeaLoopExtractor::runOnFunction(Function &F) {
// Do not modify `optnone` functions.
if (F.hasOptNone())
return false;
Expand Down Expand Up @@ -222,8 +294,8 @@ bool LoopExtractor::runOnFunction(Function &F) {
return Changed | extractLoops(TLL->begin(), TLL->end(), LI, DT);
}

bool LoopExtractor::extractLoops(Loop::iterator From, Loop::iterator To,
LoopInfo &LI, DominatorTree &DT) {
bool SeaLoopExtractor::extractLoops(Loop::iterator From, Loop::iterator To,
LoopInfo &LI, DominatorTree &DT) {
bool Changed = false;
SmallVector<Loop *, 8> Loops;

Expand All @@ -241,16 +313,19 @@ bool LoopExtractor::extractLoops(Loop::iterator From, Loop::iterator To,
return Changed;
}

bool LoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) {
bool SeaLoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) {
assert(NumLoops != 0);
Function &Func = *L->getHeader()->getParent();
AssumptionCache *AC = LookupAssumptionCache(Func);
CodeExtractorAnalysisCache CEAC(Func);
CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC);
if (Extractor.extractCodeRegion(CEAC)) {
SetVector<Value *> inputs, outputs;
auto *newFunction = Extractor.extractCodeRegion(CEAC, inputs, outputs);
if (newFunction) {
LI.erase(L);
--NumLoops;
++NumExtracted;
replaceFnBodyWithND(newFunction, inputs, outputs);
return true;
}
return false;
Expand All @@ -259,11 +334,12 @@ bool LoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) {
// createSeaSingleLoopExtractorPass - This pass extracts one natural loop from
// the program into a function if it can. This is used by bugpoint.
//
Pass *llvm::createSingleLoopExtractorPass() {
return new SingleLoopExtractor();
// Factory: allocates a SeaSingleLoopExtractor (the variant that passes a loop
// limit of 1 to its base class) and returns it as a ModulePass pointer.
ModulePass *llvm_seahorn::createSeaSingleLoopExtractorPass() {
return new SeaSingleLoopExtractor();
}

PreservedAnalyses LoopExtractorPass::run(Module &M, ModuleAnalysisManager &AM) {
PreservedAnalyses SeaLoopExtractorPass::run(Module &M,
ModuleAnalysisManager &AM) {
auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
return FAM.getResult<DominatorTreeAnalysis>(F);
Expand All @@ -274,8 +350,8 @@ PreservedAnalyses LoopExtractorPass::run(Module &M, ModuleAnalysisManager &AM) {
auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * {
return FAM.getCachedResult<AssumptionAnalysis>(F);
};
if (!LoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo,
LookupAssumptionCache)
if (!SeaLoopExtractor(NumLoops, LookupDomTree, LookupLoopInfo,
LookupAssumptionCache)
.runOnModule(M))
return PreservedAnalyses::all();

Expand All @@ -284,9 +360,9 @@ PreservedAnalyses LoopExtractorPass::run(Module &M, ModuleAnalysisManager &AM) {
return PA;
}

void LoopExtractorPass::printPipeline(
void SeaLoopExtractorPass::printPipeline(
raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
static_cast<PassInfoMixin<LoopExtractorPass> *>(this)->printPipeline(
static_cast<PassInfoMixin<SeaLoopExtractorPass> *>(this)->printPipeline(
OS, MapClassName2PassName);
OS << "<";
if (NumLoops == 1)
Expand Down

0 comments on commit ba69e31

Please sign in to comment.