Skip to content

Commit

Permalink
fix: Resolve issues in exception elmination pass
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Feliz <[email protected]>
  • Loading branch information
mfeliz-cruise committed May 5, 2022
1 parent 10b55d4 commit 99cea1b
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 2 deletions.
12 changes: 10 additions & 2 deletions core/lowering/passes/exception_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,22 @@ struct ExceptionOrPassPatternElimination {
auto arm1_start = arm1->nodes().begin();
auto arm2_start = arm2->nodes().begin();

bool arm1_starts_with_exception = (*arm1_start)->kind() == prim::RaiseException;
bool arm2_starts_with_exception = (*arm2_start)->kind() == prim::RaiseException;

if (!arm1_starts_with_exception && !arm2_starts_with_exception) {
// Neither arm matches the pattern
return false;
}

/// Check if this Node hosts a pattern like so:
/// = prim::If(%5958)
/// block0():
/// = prim::RaiseException(%45)
/// -> ()
/// block1():
/// -> ()
if ((*arm1_start)->kind() == prim::RaiseException) {
if (arm1_starts_with_exception) {
if ((*(++arm1_start))->kind() != prim::Return) {
// Make sure that block0 is solely just the exception and the return
return false;
Expand All @@ -67,7 +75,7 @@ struct ExceptionOrPassPatternElimination {
/// block1():
/// = prim::RaiseException(%45)
/// -> ()
if ((*arm2_start)->kind() == prim::RaiseException) {
if (arm2_starts_with_exception) {
if ((*(++arm2_start))->kind() != prim::Return) {
// Make sure that block1 is solely just the exception and the return
return false;
Expand Down
5 changes: 5 additions & 0 deletions tests/core/lowering/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ lowering_test(
name = "test_conv1d_pass",
)

lowering_test(
name = "test_exception_elimination_pass",
)

lowering_test(
name = "test_remove_contiguous_pass",
)
Expand Down Expand Up @@ -82,6 +86,7 @@ test_suite(
name = "lowering_tests",
tests = [
":test_conv1d_pass",
":test_exception_elimination_pass",
":test_linear_to_addmm",
":test_module_fallback_passes",
":test_operator_aliasing_pass",
Expand Down
167 changes: 167 additions & 0 deletions tests/core/lowering/test_exception_elimination_pass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
// parseIR does not support " = prim::If(%51)" with no return value
/*std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
%3 : NoneType = prim::Constant()
%4 : int = prim::Constant[value=0]()
%mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
%47 : Tensor = aten::sum(%x.1, %3)
%49 : Tensor = aten::sum(%y.1, %3)
%50 : Tensor = aten::gt(%47, %49)
%51 : bool = aten::Bool(%50)
= prim::If(%51)
block0():
= prim::RaiseException(%45)
-> ()
block1():
-> ()
%z.1 : Tensor = aten::cat(%mod_list.1, %4)
return (%z.1))IR";*/

auto g = std::make_shared<torch::jit::Graph>();
auto x = g->insertInput(0, "x");
auto y = g->insertInput(1, "y");
torch::jit::IValue zero(0);
auto zero_const_val = g->insertConstant(zero);
auto none_const_val = g->insertConstant(torch::jit::IValue());
torch::jit::IValue except("EXCEPTION");
auto except_val = g->insertConstant(except);
auto list_node = g->createList(x->type(), torch::jit::ArrayRef<torch::jit::Value*>(x));
g->insertNode(list_node);
auto sum_x_node = g->create(torch::jit::aten::sum, {x, none_const_val});
g->insertNode(sum_x_node);
auto sum_y_node = g->create(torch::jit::aten::sum, {y, none_const_val});
g->insertNode(sum_y_node);
auto gt_node = g->create(torch::jit::aten::gt, {sum_x_node->output(), sum_y_node->output()});
g->insertNode(gt_node);
auto bool_node = g->create(torch::jit::aten::Bool, {gt_node->output()});
bool_node->output()->setType(torch::jit::BoolType::get());
g->insertNode(bool_node);
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
auto if_block0 = if_node->addBlock();
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0);
if_block0->appendNode(exception_node);
auto if_block1 = if_node->addBlock();
g->insertNode(if_node);
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
g->insertNode(cat_node);
g->registerOutput(cat_node->output());

torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g);
for (auto node : g->nodes()) {
EXPECT_NE(node, if_node);
}
}

TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block1) {
// parseIR does not support " = prim::If(%51)" with no return value
/*std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
%3 : NoneType = prim::Constant()
%4 : int = prim::Constant[value=0]()
%mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
%47 : Tensor = aten::sum(%x.1, %3)
%49 : Tensor = aten::sum(%y.1, %3)
%50 : Tensor = aten::gt(%47, %49)
%51 : bool = aten::Bool(%50)
= prim::If(%51)
block0():
-> ()
block1():
= prim::RaiseException(%45)
-> ()
%z.1 : Tensor = aten::cat(%mod_list.1, %4)
return (%z.1))IR";*/

auto g = std::make_shared<torch::jit::Graph>();
auto x = g->insertInput(0, "x");
auto y = g->insertInput(1, "y");
torch::jit::IValue zero(0);
auto zero_const_val = g->insertConstant(zero);
auto none_const_val = g->insertConstant(torch::jit::IValue());
torch::jit::IValue except("EXCEPTION");
auto except_val = g->insertConstant(except);
auto list_node = g->createList(x->type(), torch::jit::ArrayRef<torch::jit::Value*>(x));
g->insertNode(list_node);
auto sum_x_node = g->create(torch::jit::aten::sum, {x, none_const_val});
g->insertNode(sum_x_node);
auto sum_y_node = g->create(torch::jit::aten::sum, {y, none_const_val});
g->insertNode(sum_y_node);
auto gt_node = g->create(torch::jit::aten::gt, {sum_x_node->output(), sum_y_node->output()});
g->insertNode(gt_node);
auto bool_node = g->create(torch::jit::aten::Bool, {gt_node->output()});
bool_node->output()->setType(torch::jit::BoolType::get());
g->insertNode(bool_node);
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
auto if_block0 = if_node->addBlock();
auto if_block1 = if_node->addBlock();
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0);
if_block1->appendNode(exception_node);
g->insertNode(if_node);
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
g->insertNode(cat_node);
g->registerOutput(cat_node->output());

torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g);
for (auto node : g->nodes()) {
EXPECT_NE(node, if_node);
}
}

TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
// parseIR does not support " = prim::If(%51)" with no return value
/*std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
%3 : NoneType = prim::Constant()
%4 : int = prim::Constant[value=0]()
%mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
%47 : Tensor = aten::sum(%x.1, %3)
%49 : Tensor = aten::sum(%y.1, %3)
%50 : Tensor = aten::gt(%47, %49)
%51 : bool = aten::Bool(%50)
= prim::If(%51)
block0():
%10 : Tensor[] = aten::append(%mod_list.1, %y.1)
-> ()
block1():
-> ()
%z.1 : Tensor = aten::cat(%mod_list.1, %4)
return (%z.1))IR";*/

auto g = std::make_shared<torch::jit::Graph>();
auto x = g->insertInput(0, "x");
auto y = g->insertInput(1, "y");
torch::jit::IValue zero(0);
auto zero_const_val = g->insertConstant(zero);
auto none_const_val = g->insertConstant(torch::jit::IValue());
auto list_node = g->createList(x->type(), torch::jit::ArrayRef<torch::jit::Value*>(x));
g->insertNode(list_node);
auto sum_x_node = g->create(torch::jit::aten::sum, {x, none_const_val});
g->insertNode(sum_x_node);
auto sum_y_node = g->create(torch::jit::aten::sum, {y, none_const_val});
g->insertNode(sum_y_node);
auto gt_node = g->create(torch::jit::aten::gt, {sum_x_node->output(), sum_y_node->output()});
g->insertNode(gt_node);
auto bool_node = g->create(torch::jit::aten::Bool, {gt_node->output()});
bool_node->output()->setType(torch::jit::BoolType::get());
g->insertNode(bool_node);
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
auto if_block0 = if_node->addBlock();
auto append_node = g->create(torch::jit::aten::append, {list_node->output(), y});
if_block0->appendNode(append_node);
auto if_block1 = if_node->addBlock();
g->insertNode(if_node);
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
g->insertNode(cat_node);
g->registerOutput(cat_node->output());

torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g);
int if_count = 0;
for (auto node : g->nodes()) {
if (node == if_node) {
if_count++;
}
}
EXPECT_EQ(1, if_count);
}

0 comments on commit 99cea1b

Please sign in to comment.