diff --git a/src/test.cc b/src/test.cc index 0db4b91..27f33b5 100644 --- a/src/test.cc +++ b/src/test.cc @@ -3,6 +3,7 @@ namespace Kokkos { void parallel_for(int i, int j) { } void parallel_scan(int i, int j) { } +void fence() { } } @@ -12,7 +13,8 @@ void test() { while (true) { int x = 10; if (true) x=5; - parallel_for(10, 20); + parallel_for(80, 90); + fence(); if (false) x=6; } } @@ -20,10 +22,17 @@ void test() { void test2() { if (true) { Kokkos::parallel_for(10, 20); + Kokkos::fence(); Kokkos::parallel_scan(10, 20); + Kokkos::fence(); } if (true) { Kokkos::parallel_for(30, 40); + Kokkos::fence(); + } + if (true) { + Kokkos::parallel_for(1219, 23); + test(); } } diff --git a/src/transform.cc b/src/transform.cc index a332fb3..0e9fda7 100644 --- a/src/transform.cc +++ b/src/transform.cc @@ -94,6 +94,111 @@ StatementMatcher CallMatcher = ) ).bind("callExpr"); +StatementMatcher FenceMatcher = + callExpr( + callee( + functionDecl( + hasName("::Kokkos::fence") + ) + ) + ).bind("fenceExpr"); + +struct FenceCallback : MatchFinder::MatchCallback { + virtual void run(const MatchFinder::MatchResult &Result) { + if (CallExpr const *ce = Result.Nodes.getNodeAs("fenceExpr")) { + found = true; + } + } + bool found = false; +}; + +template +auto getBegin(T const& t) { +#if LLVM_VERSION_MAJOR > 7 + return t->getBeginLoc(); +#else + return t->getBeginLoc(); +#endif +} + +template +auto getEnd(T const& t) { +#if LLVM_VERSION_MAJOR > 7 + return t->getEndLoc(); +#else + return t->getLocEnd(); +#endif +} + +struct RewriteBlocking { + explicit RewriteBlocking(Rewriter& in_rw) + : rw(in_rw) + { } + + void operator()(CallExpr const* par, CallExpr const* fence) const { + // Change the namespace + { + auto begin = getBegin(par); + + auto str = rw.getRewrittenText(SourceRange(begin, begin.getLocWithOffset(6))); + fmt::print("str={}\n", str); + if (str == "Kokkos::") { + rw.ReplaceText(begin, 6, "empire"); + } else { + rw.InsertTextBefore(begin, "empire::"); + } + } + + // Switch to parallel_* blocking + { + auto end = getEnd(par->getCallee()); + rw.InsertTextAfterToken(end, "_blocking"); + } + + // Remove the fence line + { + auto begin = getBegin(fence); + auto end = getEnd(fence); + typename Rewriter::RewriteOptions ro; + ro.RemoveLineIfEmpty = true; + rw.RemoveText(SourceRange{begin, end.getLocWithOffset(1)}, ro); + } + } + +private: + Rewriter& rw; +}; + +struct RewriteAsync { + explicit RewriteAsync(Rewriter& in_rw) + : rw(in_rw) + { } + + void operator()(CallExpr const* par) const { + // Change the namespace + { + auto begin = getBegin(par); + + auto str = rw.getRewrittenText(SourceRange(begin, begin.getLocWithOffset(6))); + fmt::print("str={}\n", str); + if (str == "Kokkos::") { + rw.ReplaceText(begin, 6, "empire"); + } else { + rw.InsertTextBefore(begin, "empire::"); + } + } + + // Switch to parallel_* async + { + auto end = getEnd(par->getCallee()); + rw.InsertTextAfterToken(end, "_async"); + } + } + +private: + Rewriter& rw; +}; + struct ParallelForRewriter : MatchFinder::MatchCallback { explicit ParallelForRewriter(Rewriter& in_rw) : rw(in_rw) @@ -102,52 +207,106 @@ struct ParallelForRewriter : MatchFinder::MatchCallback { virtual void run(const MatchFinder::MatchResult &Result) { if (CallExpr const *ce = Result.Nodes.getNodeAs("callExpr")) { - // Only match files based on user's input - if (Matcher != "") { - auto file = rw.getSourceMgr().getFilename(ce->getEndLoc()); + auto& ctx = Result.Context; + auto p = ctx->getParents(*ce); + fmt::print("size={}\n", p.size()); - fmt::print("considering filename={}, regex={}\n", file.str(), Matcher); - - std::regex re(Matcher); - std::cmatch m; - if (std::regex_match(file.str().c_str(), m, re)) { - fmt::print("=== Invoking rewriter on matched result in {} ===\n", file.str()); - // we need to process this match - } else { - // break out and ignore this file - return; - } + if (p.size() != 1) { + return; + } - if (processed_files.find(file.str()) != processed_files.end()) { - fmt::print("already generated for filename={}\n", file.str()); - return; + bool found_fence = false; + CallExpr const* fence = nullptr; + for (std::size_t i = 0; i < p.size(); i++) { + Stmt const* st = p[i].get(); + //st->dumpColor(); + if (isa(st)) { + auto const& cs = cast(st); + for (auto iter = cs->child_begin(); iter != cs->child_end(); ++iter) { + if (*iter == ce) { + iter++; + if (iter != cs->child_end()) { + MatchFinder fence_matcher; + auto fc = std::make_unique(); + fence_matcher.addMatcher(FenceMatcher, fc.get()); + fence_matcher.match(**iter, *Result.Context); + found_fence = fc->found; + if (fc->found) { + fence = cast(*iter); + ///fence = *iter; + fmt::print("FOUND fence\n"); + } else { + fmt::print("NOT FOUND fence\n"); + } + + } + } + } } - new_processed_files.insert(file.str()); + break; } - fmt::print( - "Traversing function \"{}\" ptr={}\n", - ce->getDirectCallee()->getNameInfo().getAsString(), - static_cast(ce) - ); - ce->dumpPretty(ce->getDirectCallee()->getASTContext()); - fmt::print("\n"); - ce->dumpColor(); - -#if CLANG5 - auto loc = ce->getLocEnd(); -#else - auto loc = ce->getEndLoc(); -#endif - - int offset = Lexer::MeasureTokenLength(loc, rw.getSourceMgr(), rw.getLangOpts()) + 1; + if (found_fence) { + // rewrite to blocking with empire + auto rb = std::make_unique(rw); + rb->operator()(ce, fence); + } else { + // rewrite to async with empire + auto ra = std::make_unique(rw); + ra->operator()(ce); + } - SourceLocation loc2 = loc.getLocWithOffset(offset); - rw.InsertText(loc2, "\nKokkos::fence();", true, true); - } else { - fmt::print(stderr, "traversing something unknown?\n"); - exit(1); +// // Only match files based on user's input +// if (Matcher != "") { +// #if LLVM_VERSION_MAJOR > 7 +// auto file = rw.getSourceMgr().getFilename(ce->getEndLoc()); +// #else +// auto file = rw.getSourceMgr().getFilename(ce->getLocEnd()); +// #endif +// fmt::print("considering filename={}, regex={}\n", file.str(), Matcher); + +// std::regex re(Matcher); +// std::cmatch m; +// if (std::regex_match(file.str().c_str(), m, re)) { +// fmt::print("=== Invoking rewriter on matched result in {} ===\n", file.str()); +// // we need to process this match +// } else { +// // break out and ignore this file +// return; +// } + +// if (processed_files.find(file.str()) != processed_files.end()) { +// fmt::print("already generated for filename={}\n", file.str()); +// return; +// } + +// new_processed_files.insert(file.str()); +// } + +// fmt::print( +// "Traversing function \"{}\" ptr={}\n", +// ce->getDirectCallee()->getNameInfo().getAsString(), +// static_cast(ce) +// ); +// ce->dumpPretty(ce->getDirectCallee()->getASTContext()); +// fmt::print("\n"); +// ce->dumpColor(); + +// #if LLVM_VERSION_MAJOR > 7 +// auto loc = ce->getEndLoc(); +// #else +// auto loc = ce->getLocEnd(); +// #endif + +// int offset = Lexer::MeasureTokenLength(loc, rw.getSourceMgr(), rw.getLangOpts()) + 1; + +// SourceLocation loc2 = loc.getLocWithOffset(offset); +// rw.InsertText(loc2, "\nKokkos::fence();", true, true); +// } else { +// fmt::print(stderr, "traversing something unknown?\n"); +// exit(1); +// } } }