Skip to content

Commit

Permalink
fix(//core/lowering): Fixes module level fallback recursion
Browse files Browse the repository at this point in the history
This commit fixes module level fallback by using method calls
to determine modules to recurse down too. This should be robust
to names other than forward used for methods as well as ignoring
functional modules.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Sep 29, 2021
1 parent 0e3532b commit f94ae8f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
21 changes: 11 additions & 10 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
torch::jit::EliminateCommonSubexpression(g);
}
torch::jit::EliminateDeadCode(g);
passes::MarkNodesForFallback(g, true);
if (lower_info.forced_fallback_modules.size() > 0) {
passes::MarkNodesForFallback(g, true);
}
passes::UnpackHardSwish(g);
passes::EliminateExceptionOrPassPattern(g);
passes::ReduceToOperation(g);
Expand All @@ -60,12 +62,13 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
LOG_GRAPH(*g);
}

torch::jit::Module LowerModule(
const torch::jit::Module& mod,
std::string method_name,
std::unordered_set<std::string> forced_fallback_modules) {
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph());
torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) {
std::unordered_set<std::string> forced_fallback_modules(
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
if (forced_fallback_modules.size() > 0) {
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph());
}
auto mod_ = torch::jit::freeze_module(mod);
LOG_GRAPH("After freeze: " << *mod_.get_method(method_name).graph());
return mod_;
Expand All @@ -77,9 +80,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
const LowerInfo& lower_info) {
LOG_DEBUG(lower_info);
LOG_GRAPH("Before lowering: " << *mod.get_method(method_name).graph());
std::unordered_set<std::string> forced_fallback_modules(
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, forced_fallback_modules);
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, lower_info);
auto g = lowered_mod.get_method(method_name).graph();

LOG_GRAPH("LibTorch Lowering");
Expand Down
37 changes: 29 additions & 8 deletions core/lowering/passes/module_fallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void NotateModuleForFallback(
if (n->kind() == torch::jit::prim::GetAttr) {
auto out_type = unmangle_cls_name(c10::toString(n->output(0)->type()));
if (forced_fallback_modules.find(out_type) != forced_fallback_modules.end()) {
LOG_DEBUG(
LOG_GRAPH(
"Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name
<< " (" << cls_name << ")]");
auto uses = n->output(0)->uses();
Expand All @@ -58,11 +58,32 @@ void NotateModuleForFallback(
}

if (changed_mod) {
LOG_DEBUG("Notated graph: " << *g);
LOG_GRAPH("Notated graph: " << *g);
}

for (const auto sub_mod : mod.named_children()) {
NotateModuleForFallback(sub_mod.value, sub_mod.name, method_name, forced_fallback_modules);
if (mod.named_children().size() > 0) {
for (const auto n : nodes) {
std::string sub_method_name = "";
if (n->kind() == torch::jit::prim::CallMethod) {
sub_method_name = n->s(c10::Symbol::attr("name"));
auto sub_mod_val = n->input(0);
auto sub_mod_src_n = sub_mod_val->node();
if (!sub_mod_src_n->hasAttributeS("name")) {
LOG_GRAPH("Node: " << util::node_info(sub_mod_src_n) << " manages a module with no name, skipping");
break;
}
auto sub_mod_name = sub_mod_src_n->s(c10::Symbol::attr("name"));
for (const auto sub_mod : mod.named_children()) {
// Theres probably a way to directly access the module we care about
if (sub_mod.name == sub_mod_name) {
LOG_GRAPH(
"Looking at <module>.<method>() next: " << sub_mod_name << "." << sub_method_name
<< "() (lowering.passes.NotateModuleForFallback)");
NotateModuleForFallback(sub_mod.value, sub_mod.name, sub_method_name, forced_fallback_modules);
}
}
}
}
}
}

Expand All @@ -74,23 +95,23 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
auto n = *it;
if (!mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
LOG_DEBUG("Starting to mark new segmented block targeted for torch");
LOG_GRAPH("Starting to mark new segmented block targeted for torch");
mark.push(true);
if (delete_delims) {
it.destroyCurrent();
}
}
} else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block");
LOG_GRAPH("Found the start of another segmented block targeted for torch while actively marking a block");
mark.push(true);
if (delete_delims) {
it.destroyCurrent();
}
}
} else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
if (n->s(c10::Symbol::attr("compilation_edge")) == "end") {
LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block");
LOG_GRAPH("Found the end of segmented block targeted for torch while actively marking a block");
mark.pop();
if (delete_delims) {
it.destroyCurrent();
Expand All @@ -106,7 +127,7 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
}
}

LOG_DEBUG("After marking operations for torch fallback: " << *g);
LOG_GRAPH("After marking operations for torch fallback: " << *g);
}

} // namespace passes
Expand Down

0 comments on commit f94ae8f

Please sign in to comment.