From f94ae8f4285b821cb8150e26dcc41bba060ce081 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 3 Sep 2021 14:47:37 -0700 Subject: [PATCH] fix(//core/lowering): Fixes module level fallback recursion This commit fixes module level fallback by using method calls to determine modules to recurse down too. This should be robust to names other than forward used for methods as well as ignoring functional modules. Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/lowering/lowering.cpp | 21 +++++++------- core/lowering/passes/module_fallback.cpp | 37 +++++++++++++++++++----- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 506a4934fe..4be3e403aa 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -37,7 +37,9 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { torch::jit::EliminateCommonSubexpression(g); } torch::jit::EliminateDeadCode(g); - passes::MarkNodesForFallback(g, true); + if (lower_info.forced_fallback_modules.size() > 0) { + passes::MarkNodesForFallback(g, true); + } passes::UnpackHardSwish(g); passes::EliminateExceptionOrPassPattern(g); passes::ReduceToOperation(g); @@ -60,12 +62,13 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { LOG_GRAPH(*g); } -torch::jit::Module LowerModule( - const torch::jit::Module& mod, - std::string method_name, - std::unordered_set forced_fallback_modules) { - passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules); - LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph()); +torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) { + std::unordered_set forced_fallback_modules( + lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end()); + if (forced_fallback_modules.size() > 0) { + passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules); + LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph()); + } auto mod_ = torch::jit::freeze_module(mod); LOG_GRAPH("After freeze: " << *mod_.get_method(method_name).graph()); return mod_; @@ -77,9 +80,7 @@ std::pair, std::vector> L const LowerInfo& lower_info) { LOG_DEBUG(lower_info); LOG_GRAPH("Before lowering: " << *mod.get_method(method_name).graph()); - std::unordered_set forced_fallback_modules( - lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end()); - auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, forced_fallback_modules); + auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, lower_info); auto g = lowered_mod.get_method(method_name).graph(); LOG_GRAPH("LibTorch Lowering"); diff --git a/core/lowering/passes/module_fallback.cpp b/core/lowering/passes/module_fallback.cpp index 9061130f4e..be7f7497b5 100644 --- a/core/lowering/passes/module_fallback.cpp +++ b/core/lowering/passes/module_fallback.cpp @@ -39,7 +39,7 @@ void NotateModuleForFallback( if (n->kind() == torch::jit::prim::GetAttr) { auto out_type = unmangle_cls_name(c10::toString(n->output(0)->type())); if (forced_fallback_modules.find(out_type) != forced_fallback_modules.end()) { - LOG_DEBUG( + LOG_GRAPH( "Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name << " (" << cls_name << ")]"); auto uses = n->output(0)->uses(); @@ -58,11 +58,32 @@ void NotateModuleForFallback( } if (changed_mod) { - LOG_DEBUG("Notated graph: " << *g); + LOG_GRAPH("Notated graph: " << *g); } - for (const auto sub_mod : mod.named_children()) { - NotateModuleForFallback(sub_mod.value, sub_mod.name, method_name, forced_fallback_modules); + if (mod.named_children().size() > 0) { + for (const auto n : nodes) { + std::string sub_method_name = ""; + if (n->kind() == torch::jit::prim::CallMethod) { + sub_method_name = n->s(c10::Symbol::attr("name")); + auto sub_mod_val = n->input(0); + auto sub_mod_src_n = sub_mod_val->node(); + if (!sub_mod_src_n->hasAttributeS("name")) { + LOG_GRAPH("Node: " << util::node_info(sub_mod_src_n) << " manages a module with no name, skipping"); + break; + } + auto sub_mod_name = sub_mod_src_n->s(c10::Symbol::attr("name")); + for (const auto sub_mod : mod.named_children()) { + // Theres probably a way to directly access the module we care about + if (sub_mod.name == sub_mod_name) { + LOG_GRAPH( + "Looking at .() next: " << sub_mod_name << "." << sub_method_name + << "() (lowering.passes.NotateModuleForFallback)"); + NotateModuleForFallback(sub_mod.value, sub_mod.name, sub_method_name, forced_fallback_modules); + } + } + } + } } } @@ -74,7 +95,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del auto n = *it; if (!mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) { if (n->s(c10::Symbol::attr("compilation_edge")) == "start") { - LOG_DEBUG("Starting to mark new segmented block targeted for torch"); + LOG_GRAPH("Starting to mark new segmented block targeted for torch"); mark.push(true); if (delete_delims) { it.destroyCurrent(); @@ -82,7 +103,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del } } else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) { if (n->s(c10::Symbol::attr("compilation_edge")) == "start") { - LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block"); + LOG_GRAPH("Found the start of another segmented block targeted for torch while actively marking a block"); mark.push(true); if (delete_delims) { it.destroyCurrent(); @@ -90,7 +111,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del } } else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) { if (n->s(c10::Symbol::attr("compilation_edge")) == "end") { - LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block"); + LOG_GRAPH("Found the end of segmented block targeted for torch while actively marking a block"); mark.pop(); if (delete_delims) { it.destroyCurrent(); @@ -106,7 +127,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del } } - LOG_DEBUG("After marking operations for torch fallback: " << *g); + LOG_GRAPH("After marking operations for torch fallback: " << *g); } } // namespace passes