forked from neo-ai/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_spli…
…t steps (apache#6142) * Add cache_read/cache_write step * Update * Add follow split and follow fused split Signed-off-by: jingbang.yjb <[email protected]> Conflicts: src/auto_scheduler/compute_dag.cc src/auto_scheduler/transform_step.cc src/auto_scheduler/transform_step.h tests/python/unittest/test_auto_scheduler_loop_state.py * add loop_state.py Signed-off-by: jingbang.yjb <[email protected]> * Update * Update * Update state->current_compute_dag to Optional * Add some doc strings for Follow_Split and Follow_fused_split Signed-off-by: jingbang.yjb <[email protected]> * Check code using c-lint Signed-off-by: jingbang.yjb <[email protected]> * Add more doc strings and change the order for follow split. Signed-off-by: jingbang.yjb <[email protected]> * Add record test for follow_split and follow_fused_split Signed-off-by: jingbang.yjb <[email protected]> * Add record test for follow_split Signed-off-by: jingbang.yjb <[email protected]> * Add record test for follow_fused_split. Signed-off-by: jingbang.yjb <[email protected]> * Add test record for follow_fused_split 1. delete a comment 2. add "fuse" between follow_split and follow_fused_split Signed-off-by: jingbang.yjb <[email protected]> * Add doc strings for some functions and variables Signed-off-by: jingbang.yjb <[email protected]> * Fix the code format in src/auto_scheduler/transform_step.h Signed-off-by: jingbang.yjb <[email protected]> * Update * Update doc * Update * Update * Fix follow_split and follow_fused_split record test. Signed-off-by: jingbang.yjb <[email protected]> * Doc update * Update some doc strings Signed-off-by: jingbang.yjb <[email protected]> * Fix code style and some function definitions. Signed-off-by: jingbang.yjb <[email protected]> * Update Signed-off-by: jingbang.yjb <[email protected]> * Add comments on parameters. Signed-off-by: jingbang.yjb <[email protected]> * Add more doc strings and fix some. Signed-off-by: jingbang.yjb <[email protected]> * Update Signed-off-by: jingbang.yjb <[email protected]> * Update Signed-off-by: jingbang.yjb <[email protected]> * Update Signed-off-by: jingbang.yjb <[email protected]> * Update. Signed-off-by: jingbang.yjb <[email protected]> Co-authored-by: chengfan.jcf <[email protected]> Co-authored-by: jingbang.yjb <[email protected]>
- Loading branch information
1 parent
1a52e18
commit dd99d55
Showing
8 changed files
with
589 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -117,6 +117,15 @@ def stages(self): | |
""" | ||
return self.state_object.stages | ||
|
||
@property | ||
def transform_steps(self): | ||
""" | ||
Returns | ||
------- | ||
transform_steps : List[transform_steps] | ||
""" | ||
return self.state_object.transform_steps | ||
|
||
@property | ||
def stage_ops(self): | ||
""" | ||
|
@@ -301,6 +310,93 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): | |
iterator, lengths, inner_to_outer) | ||
return res | ||
|
||
def follow_split(self, stage, iterator, src_step_id, n_split): | ||
""" Schedule primitive extends to split step. | ||
This step splits the iterator by the same factors as the given SplitStep. | ||
Notes | ||
------ | ||
This step is useful in a scenario that we have subgraph Dense -> Relu, | ||
and we want to compute the Dense stage at ReLU. In this case, we need them to have | ||
the same tiling structure of common outer loops. | ||
The follow_split step could be used here to split the Dense stage and makes sure its | ||
splitting factors are the same as the given split step for the ReLU stage. | ||
Parameters | ||
---------- | ||
stage : Union[int, Operation, Tensor] | ||
The Stage to be split, which can be specified by the integer index, Operation, | ||
or output tensor of the stage. | ||
iterator : Iterator | ||
The iterator to split. | ||
src_step_id : int | ||
The index of the split step to follow in the history. | ||
n_split : int | ||
The number of split level. | ||
Returns | ||
------- | ||
res_its : List[Iterator] | ||
The splitted new Iterators. | ||
""" | ||
|
||
self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, | ||
self._resolve_stage_id(stage), | ||
iterator, | ||
src_step_id, n_split) | ||
return res | ||
|
||
def follow_fused_split(self, stage, iterator, src_step_ids, level, | ||
factor_or_nparts): | ||
""" Schedule primitive extends to split step. | ||
This step is used to split an iterator by the same factors | ||
as the given list of SplitSteps and FuseSteps. | ||
Notes | ||
------ | ||
This step is useful in a scenario that we have a subgraph | ||
in GPU schedule: Input -> Dense | ||
for [email protected] = ... : Bind to blockIdx.x | ||
for [email protected] = ... : Bind to threadIdx.x | ||
for [email protected] = ... | ||
Input_shared = Input ... | ||
for k = ... | ||
Dense = ... | ||
We intend to apply cooperative fetching with the input stage, while the threadIdx.x | ||
axis is bound to an iterator generated by split & fuse step. | ||
The follow_fused_step is used split the iterator to 2 parts, while the split factor | ||
matches the final extent of the threadIdx.x bound iterator. | ||
Parameters | ||
---------- | ||
stage : Union[int, Operation, Tensor] | ||
The Stage to be split, which can be specified by the integer index, Operation, | ||
or output tensor of the stage. | ||
iterator : Iterator | ||
The iterator to split. | ||
src_step_ids : List[int] | ||
The indices of the split steps to follow in the history. | ||
level : int | ||
Use the length in this split level. | ||
factor_or_nparts : bool | ||
True to use `factor` for split from inner to outer, | ||
False to use `nparts` for split from outer to inner. | ||
Returns | ||
------- | ||
res_its : List[Iterator] | ||
The splitted new Iterators. | ||
""" | ||
|
||
self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, | ||
self._resolve_stage_id(stage), | ||
iterator, | ||
src_step_ids, level, | ||
factor_or_nparts) | ||
return res | ||
|
||
def compute_at(self, stage, target_stage, target_iter): | ||
""" Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for | ||
more details. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.