Skip to content

Commit

Permalink
fix(//core/lowering): Fixes module level fallback recursion
Browse files Browse the repository at this point in the history
This commit fixes module level fallback by using method calls
to determine modules to recurse down too. This should be robust
to names other than forward used for methods as well as ignoring
functional modules.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Oct 1, 2021
1 parent 722aa94 commit 2fc612d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 15 deletions.
21 changes: 11 additions & 10 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
torch::jit::EliminateCommonSubexpression(g);
}
torch::jit::EliminateDeadCode(g);
passes::MarkNodesForFallback(g, true);
if (lower_info.forced_fallback_modules.size() > 0) {
passes::MarkNodesForFallback(g, true);
}
passes::UnpackHardSwish(g);
passes::EliminateExceptionOrPassPattern(g);
passes::ReduceToOperation(g);
Expand All @@ -60,12 +62,13 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
LOG_GRAPH(*g);
}

torch::jit::Module LowerModule(
const torch::jit::Module& mod,
std::string method_name,
std::unordered_set<std::string> forced_fallback_modules) {
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph());
torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) {
std::unordered_set<std::string> forced_fallback_modules(
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
if (forced_fallback_modules.size() > 0) {
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph());
}
auto mod_ = torch::jit::freeze_module(mod);
LOG_GRAPH("After freeze: " << *mod_.get_method(method_name).graph());
return mod_;
Expand All @@ -77,9 +80,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
const LowerInfo& lower_info) {
LOG_DEBUG(lower_info);
LOG_GRAPH("Before lowering: " << *mod.get_method(method_name).graph());
std::unordered_set<std::string> forced_fallback_modules(
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, forced_fallback_modules);
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, lower_info);
auto g = lowered_mod.get_method(method_name).graph();

LOG_GRAPH("LibTorch Lowering");
Expand Down
31 changes: 26 additions & 5 deletions core/lowering/passes/module_fallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,29 @@ void NotateModuleForFallback(
LOG_GRAPH("Notated graph: " << *g);
}

for (const auto sub_mod : mod.named_children()) {
NotateModuleForFallback(sub_mod.value, sub_mod.name, method_name, forced_fallback_modules);
if (mod.named_children().size() > 0) {
for (const auto n : nodes) {
std::string sub_method_name = "";
if (n->kind() == torch::jit::prim::CallMethod) {
sub_method_name = n->s(c10::Symbol::attr("name"));
auto sub_mod_val = n->input(0);
auto sub_mod_src_n = sub_mod_val->node();
if (!sub_mod_src_n->hasAttributeS("name")) {
LOG_GRAPH("Node: " << util::node_info(sub_mod_src_n) << " manages a module with no name, skipping");
break;
}
auto sub_mod_name = sub_mod_src_n->s(c10::Symbol::attr("name"));
for (const auto sub_mod : mod.named_children()) {
// Theres probably a way to directly access the module we care about
if (sub_mod.name == sub_mod_name) {
LOG_GRAPH(
"Looking at <module>.<method>() next: " << sub_mod_name << "." << sub_method_name
<< "() (lowering.passes.NotateModuleForFallback)");
NotateModuleForFallback(sub_mod.value, sub_mod.name, sub_method_name, forced_fallback_modules);
}
}
}
}
}
}

Expand All @@ -74,23 +95,23 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
auto n = *it;
if (!mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
LOG_DEBUG("Starting to mark new segmented block targeted for torch");
LOG_GRAPH("Starting to mark new segmented block targeted for torch");
mark.push(true);
if (delete_delims) {
it.destroyCurrent();
}
}
} else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block");
LOG_GRAPH("Found the start of another segmented block targeted for torch while actively marking a block");
mark.push(true);
if (delete_delims) {
it.destroyCurrent();
}
}
} else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
if (n->s(c10::Symbol::attr("compilation_edge")) == "end") {
LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block");
LOG_GRAPH("Found the end of segmented block targeted for torch while actively marking a block");
mark.pop();
if (delete_delims) {
it.destroyCurrent();
Expand Down

0 comments on commit 2fc612d

Please sign in to comment.