Skip to content

Commit

Permalink
[1] Review comment handled
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored and ANSHUMAN TRIPATHY committed Aug 25, 2020
1 parent cd54666 commit 33ea539
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 31 deletions.
18 changes: 14 additions & 4 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,14 +505,24 @@ def HoistIfThenElse(variant=None):
Parameters
----------
variant : str
variant : str, optional
The variant of the pass.
variant can have any one of following values ["basic", None(Default)].
The basic variant supports basic hoisting scenarios where it exepects
the For & If Nodes are in place consecutively and does not involve
global scope variables or more advanced scenarios.
Default variant supports all hoisting scenarios,i.e., {"Basic" + "Advanced"}
supported with control with PassContext configs like below:
config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
if variant is None:
return _ffi_api.HoistIfThenElse()
return _ffi_api.HoistIfThenElseBasic()
if variant == "basic":
return _ffi_api.HoistIfThenElseBasic()
return _ffi_api.HoistIfThenElse()
5 changes: 3 additions & 2 deletions src/tir/transforms/hoist_if_then_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ class HoistCandidateSelector final : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
RemoveRecord(GetRef<ObjectRef>(op));
return;
} else {
return StmtExprVisitor::VisitStmt_(op);
}

return StmtExprVisitor::VisitStmt_(op);
}
UpdateAttrVarList(op);
StmtExprVisitor::VisitStmt_(op);
Expand Down Expand Up @@ -327,6 +327,7 @@ class HoistCandidateSelector final : public StmtExprVisitor {
return false;
}

// Ordered List maintains all ForNodes & AttrStmtNodes encountered in sequence
std::vector<const Object*> ordered_list_;
std::vector<const VarNode*> if_var_list_;
std::unordered_set<const VarNode*> attr_var_list_;
Expand Down
26 changes: 1 addition & 25 deletions tests/python/unittest/test_tir_transform_hoist_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,28 +743,4 @@ def test_hoisting_op_conv():
tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1)

if __name__ == "__main__":
test_hoist_top_for()
test_hoist_multi_var_if()
test_hoist_no_match_for()
test_no_else()
test_attr_stmt()
test_nested_for()
test_if_block()
test_multi_if()
test_no_hoisting_1()
test_no_hoisting_2()
test_no_hoisting_3()
test_no_hoisting_4()
test_no_hoisting_5()
test_no_hoisting_6()
test_no_hoisting_7()
test_hoisting_block_scope_1()
test_hoisting_block_scope_2()
test_hoisting_block_scope_3()
test_hoisting_block_scope_4()
test_hoisting_block_scope_5()
test_hoisting_block_scope_6()
test_hoisting_block_scope_7()

# Test with Conv Op
test_hoisting_op_conv()
pytest.main([__file__])

0 comments on commit 33ea539

Please sign in to comment.