Skip to content

Commit

Permalink
feat(//core/lowering): Adding two passes, one to delimit and one to mark
Browse files Browse the repository at this point in the history
ops to fallback

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Aug 20, 2021
1 parent ad07645 commit 2e04ce5
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ cc_library(
"exception_elimination.cpp",
"fuse_addmm_branches.cpp",
"linear_to_addmm.cpp",
"module_fallback.cpp",
"op_aliasing.cpp",
"reduce_to.cpp",
"remove_bn_dim_check.cpp",
Expand Down
103 changes: 103 additions & 0 deletions core/lowering/passes/module_fallback.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include <stack>
#include <unordered_set>

#include "core/lowering/passes/passes.h"
#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

std::string unmangle_cls_name(const std::string& name) {
auto unmangled = name;

std::size_t torch_prefix = unmangled.find("__torch__");
if (torch_prefix != std::string::npos) {
unmangled.erase(torch_prefix, 10);
}

std::size_t mangle_pos = unmangled.find("___torch_mangle_");
if (mangle_pos != std::string::npos) {
unmangled.erase(mangle_pos, 21);
}
return unmangled;
}

void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name, const std::string& method_name, std::unordered_set<std::string> forced_fallback_modules) {
auto cls_name = unmangle_cls_name(mod.type()->name()->qualifiedName());
auto g = mod.get_method(method_name).graph();

auto nodes = g->block()->nodes();
bool changed_mod = false;
for (const auto n : nodes) {
if (n->kind() == torch::jit::prim::GetAttr) {
auto out_type = unmangle_cls_name(c10::toString(n->output(0)->type()));
if (forced_fallback_modules.find(out_type) != forced_fallback_modules.end()) {
LOG_DEBUG("Marking module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name << " (" << cls_name << ")]");
auto uses = n->output(0)->uses();
for (const auto u : uses) {
auto user = u.user;
auto delim_start_n = g->create(torch::jit::prim::Enter, 0);
delim_start_n->s_(c10::Symbol::attr("compilation_edge"), "start");
auto num_end_outs = user->outputs().size();
auto delim_end_n = g->create(torch::jit::prim::Exit, 0);
delim_end_n->s_(c10::Symbol::attr("compilation_edge"), "end");
delim_start_n->insertBefore(user);
delim_end_n->insertAfter(user);
}
changed_mod = true;
}
}
}

if (changed_mod) {
LOG_DEBUG(*g);
}

for (const auto sub_mod : mod.named_children()) {
NotateModuleForFallback(sub_mod.value, sub_mod.name, method_name, forced_fallback_modules);
}
}

void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g) {
auto b = g->block();

std::stack<bool> mark = std::stack<bool>({false});
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
auto n = *it;
if (!mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
LOG_DEBUG("Starting to mark new segmented targeted for torch");
mark.push(true);
it.destroyCurrent();
}
} else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
if(n->s(c10::Symbol::attr("compilation_edge")) == "start") {
LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block");
mark.push(true);
it.destroyCurrent();
}
} else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
if(n->s(c10::Symbol::attr("compilation_edge")) == "end") {
LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block");
mark.pop();
it.destroyCurrent();
}
} else if (!mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
if(n->s(c10::Symbol::attr("compilation_edge")) == "end") {
LOG_WARNING("Found the end of segmented block targeted for torch while not actively marking a block");
}
} else if (mark.top()) {
LOG_GRAPH("Marking " << util::node_info(n) << " to run in PyTorch");
n->i_(c10::Symbol::attr("to_compile"), (int64_t) false);
}
}

LOG_GRAPH("Post marking ops for pytorch execution: " << *g);
}

} // Namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
3 changes: 3 additions & 0 deletions core/lowering/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ namespace core {
namespace lowering {
namespace passes {

void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name, const std::string& method_name, std::unordered_set<std::string> forced_fallback_modules);

void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g);
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
Expand Down
1 change: 1 addition & 0 deletions core/lowering/register_trt_placeholder_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "torch/csrc/jit/runtime/custom_operator.h"
#include "torch/library.h"

namespace torch {
namespace jit {
Expand Down

0 comments on commit 2e04ce5

Please sign in to comment.