Skip to content

Commit

Permalink
#1: finish transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander committed Feb 16, 2021
1 parent 520e685 commit 1ed3bb5
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 64 deletions.
44 changes: 35 additions & 9 deletions src/test.cc
Original file line number Diff line number Diff line change
@@ -1,39 +1,65 @@

#include <cstddef>

namespace std {
struct string {
string(char const* str) {}
};
};

namespace Kokkos {

void parallel_for(int i, int j) { }
void parallel_scan(int i, int j) { }
template <typename... A>
class RangePolicy {
public:
explicit RangePolicy(size_t in) {}
};

template <typename Policy>
void parallel_for(std::string str, Policy rp) { }

void parallel_for(std::string str, size_t x) { }
void parallel_for(size_t x) { }
void parallel_scan(std::string str, size_t j) { }
void parallel_scan(size_t i) { }
void fence() { }

}

void test3() {
Kokkos::parallel_for("test-123", Kokkos::RangePolicy<int, float>(1000));
}

void test() {
using namespace Kokkos;
//using namespace Kokkos;
while (true) {
int x = 10;
if (true) x=5;
parallel_for(80, 90);
fence();
Kokkos::parallel_for("80", 90);
Kokkos::fence();
if (false) x=6;
}
}

void test2() {
if (true) {
Kokkos::parallel_for(10, 20);
Kokkos::parallel_for("10", 20);
Kokkos::fence();
Kokkos::parallel_scan(10, 20);
Kokkos::parallel_scan("10", 20);
Kokkos::fence();
}
if (true) {
Kokkos::parallel_for(30, 40);
Kokkos::parallel_for(40);
Kokkos::fence();
}
if (true) {
Kokkos::parallel_for(1219, 23);
Kokkos::parallel_for("1219", 23);
test();
}

if (true) {
Kokkos::fence();
}
}

int main2() {
Expand Down
158 changes: 103 additions & 55 deletions src/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ static cl::opt<std::string> Matcher("f", cl::desc("Only transform filenames matc

static cl::list<std::string> Includes("I", cl::desc("Include directories"), cl::ZeroOrMore);

static cl::opt<bool> GenerateInline("inline", cl::desc("Generate code inline and modify files"));
static cl::opt<bool> DoFencesOnly("fence-only", cl::desc("Rewrite fences only"));

static std::set<std::string> new_processed_files;
static std::set<std::string> processed_files;
Expand Down Expand Up @@ -130,6 +130,77 @@ auto getEnd(T const& t) {
#endif
}

struct RewriteArgument {
explicit RewriteArgument(Rewriter& in_rw)
: rw(in_rw)
{ }

void operator()(CallExpr const* par) const {
bool first_is_string = false;
auto arg_iter = par->arg_begin();
auto const& first_arg = *arg_iter;
QualType const& type = first_arg->getType();
// fmt::print("type={}\n", type.getAsString());
if (type.getAsString() == "std::string") {
// fmt::print("first argument is string\n");
first_is_string = true;
arg_iter++;
}
auto const& policy = *arg_iter;
QualType const& policy_type = policy->getType();
fmt::print("policy=\"{}\"\n", policy_type.getAsString());
if (
policy_type.getAsString() == "size_t" or
policy_type.getAsString() == "std::size_t" or
policy_type.getAsString() == "int"
) {
fmt::print("range is a basic policy={}\n", policy_type.getAsString());
} else {
// advanced policy must lift to builder
fmt::print("range is advanced policy\n");
//policy_type.dump();
auto str = policy_type.getAsString();
if (str.substr(0, 6) == "class ") {
auto t = str.substr(6, str.size()-1);
std::string main_type = "";
std::string temp_type = "";
for (std::size_t i = 0; i < t.size(); i++) {
if (t[i] == '<') {
main_type = t.substr(0, i);
temp_type = t.substr(i+1, t.size()-i-2);
}
}

if (main_type == "") {
main_type = t;
}

fmt::print("MAIN: {}\n", main_type);
fmt::print("TEMP: {}\n", temp_type);

auto begin = getBegin(policy);
//auto end = getEnd(policy->getType());

std::string new_type = "";
if (temp_type == "") {
new_type = fmt::format("buildKokkosPolicy<{}>", main_type);
} else {
new_type = fmt::format("buildKokkosPolicy<{}, {}>", main_type, temp_type);
}
fmt::print("new type={}\n", new_type);
rw.ReplaceText(begin, str.size()-6, new_type);

} else {
fmt::print("failure to parse template\n");
exit(1);
}
}
}

private:
Rewriter& rw;
};

struct RewriteBlocking {
explicit RewriteBlocking(Rewriter& in_rw)
: rw(in_rw)
Expand All @@ -145,7 +216,7 @@ struct RewriteBlocking {
if (str == "Kokkos::") {
rw.ReplaceText(begin, 6, "empire");
} else {
rw.InsertTextBefore(begin, "empire::");
//rw.InsertTextBefore(begin, "empire::");
}
}

Expand Down Expand Up @@ -205,6 +276,16 @@ struct ParallelForRewriter : MatchFinder::MatchCallback {
{ }

virtual void run(const MatchFinder::MatchResult &Result) {
if (CallExpr const *ce = Result.Nodes.getNodeAs<clang::CallExpr>("fenceExpr")) {
auto begin = getBegin(ce);

auto str = rw.getRewrittenText(SourceRange(begin, begin.getLocWithOffset(6)));
fmt::print("str={}\n", str);
if (str == "Kokkos::") {
rw.ReplaceText(begin, 6, "empire");
}
}

if (CallExpr const *ce = Result.Nodes.getNodeAs<clang::CallExpr>("callExpr")) {

auto& ctx = Result.Context;
Expand All @@ -215,15 +296,25 @@ struct ParallelForRewriter : MatchFinder::MatchCallback {
return;
}

ExprWithCleanups const* ewc = nullptr;
if (isa<ExprWithCleanups>(p[0].get<Stmt>())) {
//fmt::print("isa ExprWithCleanups\n");
ewc = cast<ExprWithCleanups>(p[0].get<Stmt>());
p = ctx->getParents(*ewc);
}

bool found_fence = false;
CallExpr const* fence = nullptr;
for (std::size_t i = 0; i < p.size(); i++) {
Stmt const* st = p[i].get<Stmt>();
//st->dumpColor();
if (isa<CompoundStmt>(st)) {
auto const& cs = cast<CompoundStmt>(st);
if (cs->size() == 1) {
break;
}
for (auto iter = cs->child_begin(); iter != cs->child_end(); ++iter) {
if (*iter == ce) {
if (*iter == ce or *iter == ewc) {
iter++;
if (iter != cs->child_end()) {
MatchFinder fence_matcher;
Expand All @@ -247,6 +338,9 @@ struct ParallelForRewriter : MatchFinder::MatchCallback {
break;
}

auto rr = std::make_unique<RewriteArgument>(rw);
rr->operator()(ce);

if (found_fence) {
// rewrite to blocking with empire
auto rb = std::make_unique<RewriteBlocking>(rw);
Expand All @@ -256,57 +350,6 @@ struct ParallelForRewriter : MatchFinder::MatchCallback {
auto ra = std::make_unique<RewriteAsync>(rw);
ra->operator()(ce);
}

// // Only match files based on user's input
// if (Matcher != "") {
// #if LLVM_VERSION_MAJOR > 7
// auto file = rw.getSourceMgr().getFilename(ce->getEndLoc());
// #else
// auto file = rw.getSourceMgr().getFilename(ce->getLocEnd());
// #endif
// fmt::print("considering filename={}, regex={}\n", file.str(), Matcher);

// std::regex re(Matcher);
// std::cmatch m;
// if (std::regex_match(file.str().c_str(), m, re)) {
// fmt::print("=== Invoking rewriter on matched result in {} ===\n", file.str());
// // we need to process this match
// } else {
// // break out and ignore this file
// return;
// }

// if (processed_files.find(file.str()) != processed_files.end()) {
// fmt::print("already generated for filename={}\n", file.str());
// return;
// }

// new_processed_files.insert(file.str());
// }

// fmt::print(
// "Traversing function \"{}\" ptr={}\n",
// ce->getDirectCallee()->getNameInfo().getAsString(),
// static_cast<void const*>(ce)
// );
// ce->dumpPretty(ce->getDirectCallee()->getASTContext());
// fmt::print("\n");
// ce->dumpColor();

// #if LLVM_VERSION_MAJOR > 7
// auto loc = ce->getEndLoc();
// #else
// auto loc = ce->getLocEnd();
// #endif

// int offset = Lexer::MeasureTokenLength(loc, rw.getSourceMgr(), rw.getLangOpts()) + 1;

// SourceLocation loc2 = loc.getLocWithOffset(offset);
// rw.InsertText(loc2, "\nKokkos::fence();", true, true);
// } else {
// fmt::print(stderr, "traversing something unknown?\n");
// exit(1);
// }
}
}

Expand All @@ -315,7 +358,12 @@ struct ParallelForRewriter : MatchFinder::MatchCallback {

struct MyASTConsumer : ASTConsumer {
MyASTConsumer(Rewriter& in_rw) : call_handler_(in_rw) {
matcher_.addMatcher(CallMatcher, &call_handler_);

if (DoFencesOnly) {
matcher_.addMatcher(FenceMatcher, &call_handler_);
} else {
matcher_.addMatcher(CallMatcher, &call_handler_);
}
}

void HandleTranslationUnit(ASTContext& Context) override {
Expand Down

0 comments on commit 1ed3bb5

Please sign in to comment.