Skip to content

Commit

Permalink
#1: start implementing empire:: namespace transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander committed Feb 16, 2021
1 parent 28bdb6c commit 520e685
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 40 deletions.
11 changes: 10 additions & 1 deletion src/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ namespace Kokkos {

void parallel_for(int i, int j) { }
void parallel_scan(int i, int j) { }
void fence() { }

}

Expand All @@ -12,18 +13,26 @@ void test() {
while (true) {
int x = 10;
if (true) x=5;
parallel_for(10, 20);
parallel_for(80, 90);
fence();
if (false) x=6;
}
}

void test2() {
if (true) {
Kokkos::parallel_for(10, 20);
Kokkos::fence();
Kokkos::parallel_scan(10, 20);
Kokkos::fence();
}
if (true) {
Kokkos::parallel_for(30, 40);
Kokkos::fence();
}
if (true) {
Kokkos::parallel_for(1219, 23);
test();
}
}

Expand Down
237 changes: 198 additions & 39 deletions src/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,111 @@ StatementMatcher CallMatcher =
)
).bind("callExpr");

StatementMatcher FenceMatcher =
callExpr(
callee(
functionDecl(
hasName("::Kokkos::fence")
)
)
).bind("fenceExpr");

struct FenceCallback : MatchFinder::MatchCallback {
virtual void run(const MatchFinder::MatchResult &Result) {
if (CallExpr const *ce = Result.Nodes.getNodeAs<clang::CallExpr>("fenceExpr")) {
found = true;
}
}
bool found = false;
};

template <typename T>
auto getBegin(T const& t) {
#if LLVM_VERSION_MAJOR > 7
return t->getBeginLoc();
#else
return t->getBeginLoc();
#endif
}

template <typename T>
auto getEnd(T const& t) {
#if LLVM_VERSION_MAJOR > 7
return t->getEndLoc();
#else
return t->getLocEnd();
#endif
}

struct RewriteBlocking {
explicit RewriteBlocking(Rewriter& in_rw)
: rw(in_rw)
{ }

void operator()(CallExpr const* par, CallExpr const* fence) const {
// Change the namespace
{
auto begin = getBegin(par);

auto str = rw.getRewrittenText(SourceRange(begin, begin.getLocWithOffset(6)));
fmt::print("str={}\n", str);
if (str == "Kokkos::") {
rw.ReplaceText(begin, 6, "empire");
} else {
rw.InsertTextBefore(begin, "empire::");
}
}

// Switch to parallel_* blocking
{
auto end = getEnd(par->getCallee());
rw.InsertTextAfterToken(end, "_blocking");
}

// Remove the fence line
{
auto begin = getBegin(fence);
auto end = getEnd(fence);
typename Rewriter::RewriteOptions ro;
ro.RemoveLineIfEmpty = true;
rw.RemoveText(SourceRange{begin, end.getLocWithOffset(1)}, ro);
}
}

private:
Rewriter& rw;
};

struct RewriteAsync {
explicit RewriteAsync(Rewriter& in_rw)
: rw(in_rw)
{ }

void operator()(CallExpr const* par) const {
// Change the namespace
{
auto begin = getBegin(par);

auto str = rw.getRewrittenText(SourceRange(begin, begin.getLocWithOffset(6)));
fmt::print("str={}\n", str);
if (str == "Kokkos::") {
rw.ReplaceText(begin, 6, "empire");
} else {
rw.InsertTextBefore(begin, "empire::");
}
}

// Switch to parallel_* async
{
auto end = getEnd(par->getCallee());
rw.InsertTextAfterToken(end, "_async");
}
}

private:
Rewriter& rw;
};

struct ParallelForRewriter : MatchFinder::MatchCallback {
explicit ParallelForRewriter(Rewriter& in_rw)
: rw(in_rw)
Expand All @@ -102,52 +207,106 @@ struct ParallelForRewriter : MatchFinder::MatchCallback {
virtual void run(const MatchFinder::MatchResult &Result) {
if (CallExpr const *ce = Result.Nodes.getNodeAs<clang::CallExpr>("callExpr")) {

// Only match files based on user's input
if (Matcher != "") {
auto file = rw.getSourceMgr().getFilename(ce->getEndLoc());
auto& ctx = Result.Context;
auto p = ctx->getParents(*ce);
fmt::print("size={}\n", p.size());

fmt::print("considering filename={}, regex={}\n", file.str(), Matcher);

std::regex re(Matcher);
std::cmatch m;
if (std::regex_match(file.str().c_str(), m, re)) {
fmt::print("=== Invoking rewriter on matched result in {} ===\n", file.str());
// we need to process this match
} else {
// break out and ignore this file
return;
}
if (p.size() != 1) {
return;
}

if (processed_files.find(file.str()) != processed_files.end()) {
fmt::print("already generated for filename={}\n", file.str());
return;
bool found_fence = false;
CallExpr const* fence = nullptr;
for (std::size_t i = 0; i < p.size(); i++) {
Stmt const* st = p[i].get<Stmt>();
//st->dumpColor();
if (isa<CompoundStmt>(st)) {
auto const& cs = cast<CompoundStmt>(st);
for (auto iter = cs->child_begin(); iter != cs->child_end(); ++iter) {
if (*iter == ce) {
iter++;
if (iter != cs->child_end()) {
MatchFinder fence_matcher;
auto fc = std::make_unique<FenceCallback>();
fence_matcher.addMatcher(FenceMatcher, fc.get());
fence_matcher.match(**iter, *Result.Context);
found_fence = fc->found;
if (fc->found) {
fence = cast<CallExpr>(*iter);
///fence = *iter;
fmt::print("FOUND fence\n");
} else {
fmt::print("NOT FOUND fence\n");
}

}
}
}
}

new_processed_files.insert(file.str());
break;
}

fmt::print(
"Traversing function \"{}\" ptr={}\n",
ce->getDirectCallee()->getNameInfo().getAsString(),
static_cast<void const*>(ce)
);
ce->dumpPretty(ce->getDirectCallee()->getASTContext());
fmt::print("\n");
ce->dumpColor();

#if CLANG5
auto loc = ce->getLocEnd();
#else
auto loc = ce->getEndLoc();
#endif

int offset = Lexer::MeasureTokenLength(loc, rw.getSourceMgr(), rw.getLangOpts()) + 1;
if (found_fence) {
// rewrite to blocking with empire
auto rb = std::make_unique<RewriteBlocking>(rw);
rb->operator()(ce, fence);
} else {
// rewrite to async with empire
auto ra = std::make_unique<RewriteAsync>(rw);
ra->operator()(ce);
}

SourceLocation loc2 = loc.getLocWithOffset(offset);
rw.InsertText(loc2, "\nKokkos::fence();", true, true);
} else {
fmt::print(stderr, "traversing something unknown?\n");
exit(1);
// // Only match files based on user's input
// if (Matcher != "") {
// #if LLVM_VERSION_MAJOR > 7
// auto file = rw.getSourceMgr().getFilename(ce->getEndLoc());
// #else
// auto file = rw.getSourceMgr().getFilename(ce->getLocEnd());
// #endif
// fmt::print("considering filename={}, regex={}\n", file.str(), Matcher);

// std::regex re(Matcher);
// std::cmatch m;
// if (std::regex_match(file.str().c_str(), m, re)) {
// fmt::print("=== Invoking rewriter on matched result in {} ===\n", file.str());
// // we need to process this match
// } else {
// // break out and ignore this file
// return;
// }

// if (processed_files.find(file.str()) != processed_files.end()) {
// fmt::print("already generated for filename={}\n", file.str());
// return;
// }

// new_processed_files.insert(file.str());
// }

// fmt::print(
// "Traversing function \"{}\" ptr={}\n",
// ce->getDirectCallee()->getNameInfo().getAsString(),
// static_cast<void const*>(ce)
// );
// ce->dumpPretty(ce->getDirectCallee()->getASTContext());
// fmt::print("\n");
// ce->dumpColor();

// #if LLVM_VERSION_MAJOR > 7
// auto loc = ce->getEndLoc();
// #else
// auto loc = ce->getLocEnd();
// #endif

// int offset = Lexer::MeasureTokenLength(loc, rw.getSourceMgr(), rw.getLangOpts()) + 1;

// SourceLocation loc2 = loc.getLocWithOffset(offset);
// rw.InsertText(loc2, "\nKokkos::fence();", true, true);
// } else {
// fmt::print(stderr, "traversing something unknown?\n");
// exit(1);
// }
}
}

Expand Down

0 comments on commit 520e685

Please sign in to comment.