From 1ed3bb57137f233e40b255d9088db8887b9c89e3 Mon Sep 17 00:00:00 2001 From: Jonathan Lifflander Date: Mon, 15 Feb 2021 10:00:18 -0800 Subject: [PATCH] #1: finish transformation --- src/test.cc | 44 ++++++++++--- src/transform.cc | 158 ++++++++++++++++++++++++++++++----------------- 2 files changed, 138 insertions(+), 64 deletions(-) diff --git a/src/test.cc b/src/test.cc index 27f33b5..b8f055c 100644 --- a/src/test.cc +++ b/src/test.cc @@ -1,39 +1,65 @@ +#include + +namespace std { +struct string { + string(char const* str) {} +}; +}; + namespace Kokkos { -void parallel_for(int i, int j) { } -void parallel_scan(int i, int j) { } +template +class RangePolicy { +public: + explicit RangePolicy(size_t in) {} +}; + +template +void parallel_for(std::string str, Policy rp) { } + +void parallel_for(std::string str, size_t x) { } +void parallel_for(size_t x) { } +void parallel_scan(std::string str, size_t j) { } +void parallel_scan(size_t i) { } void fence() { } } +void test3() { + Kokkos::parallel_for("test-123", Kokkos::RangePolicy(1000)); +} void test() { - using namespace Kokkos; + //using namespace Kokkos; while (true) { int x = 10; if (true) x=5; - parallel_for(80, 90); - fence(); + Kokkos::parallel_for("80", 90); + Kokkos::fence(); if (false) x=6; } } void test2() { if (true) { - Kokkos::parallel_for(10, 20); + Kokkos::parallel_for("10", 20); Kokkos::fence(); - Kokkos::parallel_scan(10, 20); + Kokkos::parallel_scan("10", 20); Kokkos::fence(); } if (true) { - Kokkos::parallel_for(30, 40); + Kokkos::parallel_for(40); Kokkos::fence(); } if (true) { - Kokkos::parallel_for(1219, 23); + Kokkos::parallel_for("1219", 23); test(); } + + if (true) { + Kokkos::fence(); + } } int main2() { diff --git a/src/transform.cc b/src/transform.cc index 0e9fda7..41cdaa6 100644 --- a/src/transform.cc +++ b/src/transform.cc @@ -80,7 +80,7 @@ static cl::opt Matcher("f", cl::desc("Only transform filenames matc static cl::list Includes("I", cl::desc("Include directories"), cl::ZeroOrMore); -static cl::opt GenerateInline("inline", cl::desc("Generate code inline and modify files")); +static cl::opt DoFencesOnly("fence-only", cl::desc("Rewrite fences only")); static std::set new_processed_files; static std::set processed_files; @@ -130,6 +130,77 @@ auto getEnd(T const& t) { #endif } +struct RewriteArgument { + explicit RewriteArgument(Rewriter& in_rw) + : rw(in_rw) + { } + + void operator()(CallExpr const* par) const { + bool first_is_string = false; + auto arg_iter = par->arg_begin(); + auto const& first_arg = *arg_iter; + QualType const& type = first_arg->getType(); + // fmt::print("type={}\n", type.getAsString()); + if (type.getAsString() == "std::string") { + // fmt::print("first argument is string\n"); + first_is_string = true; + arg_iter++; + } + auto const& policy = *arg_iter; + QualType const& policy_type = policy->getType(); + fmt::print("policy=\"{}\"\n", policy_type.getAsString()); + if ( + policy_type.getAsString() == "size_t" or + policy_type.getAsString() == "std::size_t" or + policy_type.getAsString() == "int" + ) { + fmt::print("range is a basic policy={}\n", policy_type.getAsString()); + } else { + // advanced policy must lift to builder + fmt::print("range is advanced policy\n"); + //policy_type.dump(); + auto str = policy_type.getAsString(); + if (str.substr(0, 6) == "class ") { + auto t = str.substr(6, str.size()-1); + std::string main_type = ""; + std::string temp_type = ""; + for (std::size_t i = 0; i < t.size(); i++) { + if (t[i] == '<') { + main_type = t.substr(0, i); + temp_type = t.substr(i+1, t.size()-i-2); + } + } + + if (main_type == "") { + main_type = t; + } + + fmt::print("MAIN: {}\n", main_type); + fmt::print("TEMP: {}\n", temp_type); + + auto begin = getBegin(policy); + //auto end = getEnd(policy->getType()); + + std::string new_type = ""; + if (temp_type == "") { + new_type = fmt::format("buildKokkosPolicy<{}>", main_type); + } else { + new_type = fmt::format("buildKokkosPolicy<{}, {}>", main_type, temp_type); + } + fmt::print("new type={}\n", new_type); + rw.ReplaceText(begin, str.size()-6, new_type); + + } else { + fmt::print("failure to parse template\n"); + exit(1); + } + } + } + +private: + Rewriter& rw; +}; + struct RewriteBlocking { explicit RewriteBlocking(Rewriter& in_rw) : rw(in_rw) @@ -145,7 +216,7 @@ struct RewriteBlocking { if (str == "Kokkos::") { rw.ReplaceText(begin, 6, "empire"); } else { - rw.InsertTextBefore(begin, "empire::"); + //rw.InsertTextBefore(begin, "empire::"); } } @@ -205,6 +276,16 @@ struct ParallelForRewriter : MatchFinder::MatchCallback { { } virtual void run(const MatchFinder::MatchResult &Result) { + if (CallExpr const *ce = Result.Nodes.getNodeAs("fenceExpr")) { + auto begin = getBegin(ce); + + auto str = rw.getRewrittenText(SourceRange(begin, begin.getLocWithOffset(6))); + fmt::print("str={}\n", str); + if (str == "Kokkos::") { + rw.ReplaceText(begin, 6, "empire"); + } + } + if (CallExpr const *ce = Result.Nodes.getNodeAs("callExpr")) { auto& ctx = Result.Context; @@ -215,6 +296,13 @@ struct ParallelForRewriter : MatchFinder::MatchCallback { return; } + ExprWithCleanups const* ewc = nullptr; + if (isa(p[0].get())) { + //fmt::print("isa ExprWithCleanups\n"); + ewc = cast(p[0].get()); + p = ctx->getParents(*ewc); + } + bool found_fence = false; CallExpr const* fence = nullptr; for (std::size_t i = 0; i < p.size(); i++) { @@ -222,8 +310,11 @@ struct ParallelForRewriter : MatchFinder::MatchCallback { //st->dumpColor(); if (isa(st)) { auto const& cs = cast(st); + if (cs->size() == 1) { + break; + } for (auto iter = cs->child_begin(); iter != cs->child_end(); ++iter) { - if (*iter == ce) { + if (*iter == ce or *iter == ewc) { iter++; if (iter != cs->child_end()) { MatchFinder fence_matcher; @@ -247,6 +338,9 @@ struct ParallelForRewriter : MatchFinder::MatchCallback { break; } + auto rr = std::make_unique(rw); + rr->operator()(ce); + if (found_fence) { // rewrite to blocking with empire auto rb = std::make_unique(rw); @@ -256,57 +350,6 @@ struct ParallelForRewriter : MatchFinder::MatchCallback { auto ra = std::make_unique(rw); ra->operator()(ce); } - -// // Only match files based on user's input -// if (Matcher != "") { -// #if LLVM_VERSION_MAJOR > 7 -// auto file = rw.getSourceMgr().getFilename(ce->getEndLoc()); -// #else -// auto file = rw.getSourceMgr().getFilename(ce->getLocEnd()); -// #endif -// fmt::print("considering filename={}, regex={}\n", file.str(), Matcher); - -// std::regex re(Matcher); -// std::cmatch m; -// if (std::regex_match(file.str().c_str(), m, re)) { -// fmt::print("=== Invoking rewriter on matched result in {} ===\n", file.str()); -// // we need to process this match -// } else { -// // break out and ignore this file -// return; -// } - -// if (processed_files.find(file.str()) != processed_files.end()) { -// fmt::print("already generated for filename={}\n", file.str()); -// return; -// } - -// new_processed_files.insert(file.str()); -// } - -// fmt::print( -// "Traversing function \"{}\" ptr={}\n", -// ce->getDirectCallee()->getNameInfo().getAsString(), -// static_cast(ce) -// ); -// ce->dumpPretty(ce->getDirectCallee()->getASTContext()); -// fmt::print("\n"); -// ce->dumpColor(); - -// #if LLVM_VERSION_MAJOR > 7 -// auto loc = ce->getEndLoc(); -// #else -// auto loc = ce->getLocEnd(); -// #endif - -// int offset = Lexer::MeasureTokenLength(loc, rw.getSourceMgr(), rw.getLangOpts()) + 1; - -// SourceLocation loc2 = loc.getLocWithOffset(offset); -// rw.InsertText(loc2, "\nKokkos::fence();", true, true); -// } else { -// fmt::print(stderr, "traversing something unknown?\n"); -// exit(1); -// } } } @@ -315,7 +358,12 @@ struct ParallelForRewriter : MatchFinder::MatchCallback { struct MyASTConsumer : ASTConsumer { MyASTConsumer(Rewriter& in_rw) : call_handler_(in_rw) { - matcher_.addMatcher(CallMatcher, &call_handler_); + + if (DoFencesOnly) { + matcher_.addMatcher(FenceMatcher, &call_handler_); + } else { + matcher_.addMatcher(CallMatcher, &call_handler_); + } } void HandleTranslationUnit(ASTContext& Context) override {