From 33ea5398f4dfcf3f4952e97b9dd0f11efc955ec2 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Tue, 25 Aug 2020 14:26:31 +0530 Subject: [PATCH] [1] Review comment handled --- python/tvm/tir/transform/transform.py | 18 ++++++++++--- src/tir/transforms/hoist_if_then_else.cc | 5 ++-- .../unittest/test_tir_transform_hoist_if.py | 26 +------------------ 3 files changed, 18 insertions(+), 31 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 467dd7c91330..55dc98c72462 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -505,14 +505,24 @@ def HoistIfThenElse(variant=None): Parameters ---------- - variant : str + variant : str, optional The variant of the pass. + variant can have any one of following values ["basic", None(Default)]. + + The basic variant supports basic hoisting scenarios where it exepects + the For & If Nodes are in place consecutively and does not involve + global scope variables or more advanced scenarios. + + Default variant supports all hoisting scenarios,i.e., {"Basic" + "Advanced"} + supported with control with PassContext configs like below: + + config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}} Returns ------- fpass : tvm.transform.Pass The result pass """ - if variant is None: - return _ffi_api.HoistIfThenElse() - return _ffi_api.HoistIfThenElseBasic() + if variant == "basic": + return _ffi_api.HoistIfThenElseBasic() + return _ffi_api.HoistIfThenElse() diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc index 1ac5e1ce701a..4e7589c3a795 100644 --- a/src/tir/transforms/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -176,9 +176,9 @@ class HoistCandidateSelector final : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); RemoveRecord(GetRef(op)); return; + } else { + return StmtExprVisitor::VisitStmt_(op); } - - return StmtExprVisitor::VisitStmt_(op); } UpdateAttrVarList(op); StmtExprVisitor::VisitStmt_(op); @@ -327,6 +327,7 @@ class HoistCandidateSelector final : public StmtExprVisitor { return false; } + // Ordered List maintains all ForNodes & AttrStmtNodes encountered in sequence std::vector ordered_list_; std::vector if_var_list_; std::unordered_set attr_var_list_; diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index b2e23dc1b7f0..186a52d12da1 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -743,28 +743,4 @@ def test_hoisting_op_conv(): tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1) if __name__ == "__main__": - test_hoist_top_for() - test_hoist_multi_var_if() - test_hoist_no_match_for() - test_no_else() - test_attr_stmt() - test_nested_for() - test_if_block() - test_multi_if() - test_no_hoisting_1() - test_no_hoisting_2() - test_no_hoisting_3() - test_no_hoisting_4() - test_no_hoisting_5() - test_no_hoisting_6() - test_no_hoisting_7() - test_hoisting_block_scope_1() - test_hoisting_block_scope_2() - test_hoisting_block_scope_3() - test_hoisting_block_scope_4() - test_hoisting_block_scope_5() - test_hoisting_block_scope_6() - test_hoisting_block_scope_7() - - # Test with Conv Op - test_hoisting_op_conv() + pytest.main([__file__])