Skip to content

Commit

Permalink
fix(exception_elimination): Exception branches are no longer consistent
Browse files Browse the repository at this point in the history
so cover both cases

Also adjusts Half precision thresholds in python

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Aug 2, 2021
1 parent a12d249 commit d61b667
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 15 deletions.
48 changes: 35 additions & 13 deletions core/lowering/passes/exception_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,6 @@ struct ExceptionOrPassPatternElimination {

private:
bool isExceptionOrPassNode(Node* n) {
/// Check if this Node hosts a pattern like so:
/// = prim::If(%5958)
/// block0():
/// = prim::RaiseException(%45)
/// -> ()
/// block1():
/// -> ()
if (n->blocks().size() != 2) {
return false;
}
Expand All @@ -46,15 +39,44 @@ struct ExceptionOrPassPatternElimination {
}

auto arm1_start = arm1->nodes().begin();
auto arm2_start = arm2->nodes().begin();

if ((*arm1_start)->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) {
// Make sure that block0 is solely just the exception and the return
return false;
/// Check if this Node hosts a pattern like so:
/// = prim::If(%5958)
/// block0():
/// = prim::RaiseException(%45)
/// -> ()
/// block1():
/// -> ()
if ((*arm1_start)->kind() == prim::RaiseException) {
if ((*(++arm1_start))->kind() != prim::Return) {
// Make sure that block0 is solely just the exception and the return
return false;
}

if ((*(arm2_start))->kind() != prim::Return) {
// Make sure that block1 is solely the return
return false;
}
}

if ((*(arm2->nodes().begin()))->kind() != prim::Return) {
// Make sure that block1 is solely the return
return false;
/// Check if this Node hosts a pattern like so:
/// = prim::If(%5958)
/// block0():
/// -> ()
/// block1():
/// = prim::RaiseException(%45)
/// -> ()
if ((*arm2_start)->kind() == prim::RaiseException) {
if ((*(++arm2_start))->kind() != prim::Return) {
// Make sure that block1 is solely just the exception and the return
return false;
}

if ((*(arm1_start))->kind() != prim::Return) {
// Make sure that block0 is solely the return
return false;
}
}

return true;
Expand Down
6 changes: 4 additions & 2 deletions tests/py/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def test_compile_script_half(self):

trt_mod = trtorch.compile(self.scripted_model, compile_spec)
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
self.assertTrue(same < 2e-2)
trtorch.logging.log(trtorch.logging.Level.Debug, "Max diff: " + str(same))
self.assertTrue(same < 3e-2)


class TestCompileHalfDefault(ModelTestCase):
Expand All @@ -115,7 +116,8 @@ def test_compile_script_half_by_default(self):

trt_mod = trtorch.compile(self.scripted_model, compile_spec)
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
self.assertTrue(same < 2e-2)
trtorch.logging.log(trtorch.logging.Level.Debug, "Max diff: " + str(same))
self.assertTrue(same < 3e-2)


class TestFallbackToTorch(ModelTestCase):
Expand Down

0 comments on commit d61b667

Please sign in to comment.