Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps #6142

Merged
merged 35 commits into from
Jul 28, 2020

Conversation

jiuqi-yang
Copy link
Contributor

For the full upstream plan, see Ansor RFC.

In this PR, we bring follow split and follow fused split steps for Ansor auto_scheduler.

cc @merrymercy @comaniac @junrushao1994 @FrozenGene @jroesch

jcf94 and others added 26 commits July 22, 2020 09:55
Signed-off-by: jingbang.yjb <[email protected]>

Conflicts:
	src/auto_scheduler/compute_dag.cc
	src/auto_scheduler/transform_step.cc
	src/auto_scheduler/transform_step.h
	tests/python/unittest/test_auto_scheduler_loop_state.py
Signed-off-by: jingbang.yjb <[email protected]>
Signed-off-by: jingbang.yjb <[email protected]>
Signed-off-by: jingbang.yjb <[email protected]>
1. delete a comment
2. add "fuse" between follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>
Signed-off-by: jingbang.yjb <[email protected]>
Signed-off-by: jingbang.yjb <[email protected]>

Conflicts:
	include/tvm/auto_scheduler/loop_state.h
	include/tvm/auto_scheduler/transform_step.h
	src/auto_scheduler/compute_dag.cc
	src/auto_scheduler/compute_dag.h
	src/auto_scheduler/loop_state.cc
	src/auto_scheduler/transform_step.cc
	tests/python/unittest/test_auto_scheduler_loop_state.py
	tests/python/unittest/test_auto_scheduler_measure.py
include/tvm/auto_scheduler/loop_state.h Outdated Show resolved Hide resolved
include/tvm/auto_scheduler/loop_state.h Outdated Show resolved Hide resolved
include/tvm/auto_scheduler/transform_step.h Outdated Show resolved Hide resolved
src/auto_scheduler/loop_state.cc Outdated Show resolved Hide resolved
src/auto_scheduler/loop_state.cc Outdated Show resolved Hide resolved
src/auto_scheduler/loop_state.cc Outdated Show resolved Hide resolved
src/auto_scheduler/transform_step.cc Outdated Show resolved Hide resolved
tests/python/unittest/test_auto_scheduler_loop_state.py Outdated Show resolved Hide resolved
tests/python/unittest/test_auto_scheduler_measure.py Outdated Show resolved Hide resolved
@jcf94
Copy link
Contributor

jcf94 commented Jul 27, 2020

Hi, all. This is an student intern of us, who is now helping us with the Ansor upstreaming. 😄

The follow_split & follow_fused_split are two steps extent to te.Stage.Split. Each of these will collect information from the former history and process the split.

FollowSplit

This is mainly used in stage fusion using compute at.
For example we have stages: Dense -> Relu:
We've already done some tiling on Relu, and we would like to compute the Dense at the Relu stage. FollowSplit step is used to keep the outer most few iterators of Dense the same as the Relu stage.
Since in Ansor, the split factor of Relu stage may be left as a None placeholder to be filled by search policy, by this way we can easily write a schedule with some kind of dynamic dependence.

FollowFusedSplit

This is mainly used in GPU cooperative fetching.
For example we have stages: Input -> Dense:

            for [email protected] = ... : Bind to blockIdx.x
                for [email protected] = ... : Bind to threadIdx.x
                    for [email protected] = ...
                        Input_shared = Input ...
                        for k = ...
                            Dense = ...

In Ansor's search policy, the outer stage has been tiled. The the threadIdx.x axis is binded to a iterator generated by split & fuse step.
We use FollowFusedSplit step to compute out the final extent of the threadIdx.x binded iterator, to make sure that Input_shared stage can split out a iterator with same extent.

@yangjunpro
Copy link

Thanks @jiuqi-yang for the nice work, @merrymercy @tqchen @FrozenGene @comaniac , would you please take a look at this PR? We are trying to accelerate the auto-schedule upstreaming process.

jingbang.yjb added 2 commits July 27, 2020 15:56
Signed-off-by: jingbang.yjb <[email protected]>
Signed-off-by: jingbang.yjb <[email protected]>
Copy link

@yangjunpro yangjunpro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some nitpicking comments.

src/auto_scheduler/transform_step.cc Show resolved Hide resolved
@@ -136,6 +144,10 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<SplitStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<FollowSplitStepNode>()) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
writer->WriteArraySeperator();
writer->WriteString(record_prefix_str);
writer->WriteArrayItem(stage_id);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will happen if the order of writing is changed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order here corresponds to the read order defined in the constructor of this step.

src/auto_scheduler/transform_step.cc Show resolved Hide resolved
src/auto_scheduler/transform_step.cc Show resolved Hide resolved
python/tvm/auto_scheduler/loop_state.py Outdated Show resolved Hide resolved
python/tvm/auto_scheduler/loop_state.py Outdated Show resolved Hide resolved
python/tvm/auto_scheduler/loop_state.py Outdated Show resolved Hide resolved
include/tvm/auto_scheduler/transform_step.h Outdated Show resolved Hide resolved
src/auto_scheduler/transform_step.cc Show resolved Hide resolved
src/auto_scheduler/transform_step.cc Show resolved Hide resolved
src/auto_scheduler/transform_step.cc Show resolved Hide resolved
src/auto_scheduler/transform_step.cc Show resolved Hide resolved
src/auto_scheduler/transform_step.cc Show resolved Hide resolved
src/auto_scheduler/transform_step.cc Show resolved Hide resolved
include/tvm/auto_scheduler/loop_state.h Outdated Show resolved Hide resolved
include/tvm/auto_scheduler/loop_state.h Outdated Show resolved Hide resolved
python/tvm/auto_scheduler/loop_state.py Outdated Show resolved Hide resolved
src/auto_scheduler/transform_step.cc Show resolved Hide resolved
src/auto_scheduler/transform_step.cc Show resolved Hide resolved
jingbang.yjb added 3 commits July 28, 2020 10:13
Signed-off-by: jingbang.yjb <[email protected]>
Signed-off-by: jingbang.yjb <[email protected]>
Copy link
Contributor

@jcf94 jcf94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Others look good to me. Just update these descriptions to follow the other functions.

include/tvm/auto_scheduler/transform_step.h Outdated Show resolved Hide resolved
include/tvm/auto_scheduler/transform_step.h Outdated Show resolved Hide resolved
include/tvm/auto_scheduler/transform_step.h Outdated Show resolved Hide resolved
include/tvm/auto_scheduler/transform_step.h Outdated Show resolved Hide resolved
include/tvm/auto_scheduler/transform_step.h Outdated Show resolved Hide resolved
include/tvm/auto_scheduler/transform_step.h Outdated Show resolved Hide resolved
jingbang.yjb added 3 commits July 28, 2020 14:20
Signed-off-by: jingbang.yjb <[email protected]>
Signed-off-by: jingbang.yjb <[email protected]>
Copy link
Contributor

@jcf94 jcf94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine to me, I'll rebase the #6141 after this has been merged.

@merrymercy merrymercy merged commit bbc2dbf into apache:master Jul 28, 2020
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Aug 26, 2020
…t steps (apache#6142)

* Add cache_read/cache_write step

* Update

* Add follow split and follow fused split

Signed-off-by: jingbang.yjb <[email protected]>

Conflicts:
	src/auto_scheduler/compute_dag.cc
	src/auto_scheduler/transform_step.cc
	src/auto_scheduler/transform_step.h
	tests/python/unittest/test_auto_scheduler_loop_state.py

* add loop_state.py

Signed-off-by: jingbang.yjb <[email protected]>

* Update

* Update

* Update state->current_compute_dag to Optional

* Add some doc strings for Follow_Split and Follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Check code using c-lint

Signed-off-by: jingbang.yjb <[email protected]>

* Add more doc strings and change the order for follow split.

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_fused_split.

Signed-off-by: jingbang.yjb <[email protected]>

* Add test record for follow_fused_split
1. delete a comment
2. add "fuse" between follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add doc strings for some functions and variables

Signed-off-by: jingbang.yjb <[email protected]>

* Fix the code format in src/auto_scheduler/transform_step.h

Signed-off-by: jingbang.yjb <[email protected]>

* Update

* Update doc

* Update

* Update

* Fix follow_split and follow_fused_split record test.

Signed-off-by: jingbang.yjb <[email protected]>

* Doc update

* Update some doc strings

Signed-off-by: jingbang.yjb <[email protected]>

* Fix code style and some function definitions.

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Add comments on parameters.

Signed-off-by: jingbang.yjb <[email protected]>

* Add more doc strings and fix some.

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update.

Signed-off-by: jingbang.yjb <[email protected]>

Co-authored-by: chengfan.jcf <[email protected]>
Co-authored-by: jingbang.yjb <[email protected]>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Aug 26, 2020
…t steps (apache#6142)

* Add cache_read/cache_write step

* Update

* Add follow split and follow fused split

Signed-off-by: jingbang.yjb <[email protected]>

Conflicts:
	src/auto_scheduler/compute_dag.cc
	src/auto_scheduler/transform_step.cc
	src/auto_scheduler/transform_step.h
	tests/python/unittest/test_auto_scheduler_loop_state.py

* add loop_state.py

Signed-off-by: jingbang.yjb <[email protected]>

* Update

* Update

* Update state->current_compute_dag to Optional

* Add some doc strings for Follow_Split and Follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Check code using c-lint

Signed-off-by: jingbang.yjb <[email protected]>

* Add more doc strings and change the order for follow split.

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_fused_split.

Signed-off-by: jingbang.yjb <[email protected]>

* Add test record for follow_fused_split
1. delete a comment
2. add "fuse" between follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add doc strings for some functions and variables

Signed-off-by: jingbang.yjb <[email protected]>

* Fix the code format in src/auto_scheduler/transform_step.h

Signed-off-by: jingbang.yjb <[email protected]>

* Update

* Update doc

* Update

* Update

* Fix follow_split and follow_fused_split record test.

Signed-off-by: jingbang.yjb <[email protected]>

* Doc update

* Update some doc strings

Signed-off-by: jingbang.yjb <[email protected]>

* Fix code style and some function definitions.

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Add comments on parameters.

Signed-off-by: jingbang.yjb <[email protected]>

* Add more doc strings and fix some.

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update.

Signed-off-by: jingbang.yjb <[email protected]>

Co-authored-by: chengfan.jcf <[email protected]>
Co-authored-by: jingbang.yjb <[email protected]>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Aug 26, 2020
…t steps (apache#6142)

* Add cache_read/cache_write step

* Update

* Add follow split and follow fused split

Signed-off-by: jingbang.yjb <[email protected]>

Conflicts:
	src/auto_scheduler/compute_dag.cc
	src/auto_scheduler/transform_step.cc
	src/auto_scheduler/transform_step.h
	tests/python/unittest/test_auto_scheduler_loop_state.py

* add loop_state.py

Signed-off-by: jingbang.yjb <[email protected]>

* Update

* Update

* Update state->current_compute_dag to Optional

* Add some doc strings for Follow_Split and Follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Check code using c-lint

Signed-off-by: jingbang.yjb <[email protected]>

* Add more doc strings and change the order for follow split.

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_fused_split.

Signed-off-by: jingbang.yjb <[email protected]>

* Add test record for follow_fused_split
1. delete a comment
2. add "fuse" between follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add doc strings for some functions and variables

Signed-off-by: jingbang.yjb <[email protected]>

* Fix the code format in src/auto_scheduler/transform_step.h

Signed-off-by: jingbang.yjb <[email protected]>

* Update

* Update doc

* Update

* Update

* Fix follow_split and follow_fused_split record test.

Signed-off-by: jingbang.yjb <[email protected]>

* Doc update

* Update some doc strings

Signed-off-by: jingbang.yjb <[email protected]>

* Fix code style and some function definitions.

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Add comments on parameters.

Signed-off-by: jingbang.yjb <[email protected]>

* Add more doc strings and fix some.

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update.

Signed-off-by: jingbang.yjb <[email protected]>

Co-authored-by: chengfan.jcf <[email protected]>
Co-authored-by: jingbang.yjb <[email protected]>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Sep 2, 2020
…t steps (apache#6142)

* Add cache_read/cache_write step

* Update

* Add follow split and follow fused split

Signed-off-by: jingbang.yjb <[email protected]>

Conflicts:
	src/auto_scheduler/compute_dag.cc
	src/auto_scheduler/transform_step.cc
	src/auto_scheduler/transform_step.h
	tests/python/unittest/test_auto_scheduler_loop_state.py

* add loop_state.py

Signed-off-by: jingbang.yjb <[email protected]>

* Update

* Update

* Update state->current_compute_dag to Optional

* Add some doc strings for Follow_Split and Follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Check code using c-lint

Signed-off-by: jingbang.yjb <[email protected]>

* Add more doc strings and change the order for follow split.

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_fused_split.

Signed-off-by: jingbang.yjb <[email protected]>

* Add test record for follow_fused_split
1. delete a comment
2. add "fuse" between follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add doc strings for some functions and variables

Signed-off-by: jingbang.yjb <[email protected]>

* Fix the code format in src/auto_scheduler/transform_step.h

Signed-off-by: jingbang.yjb <[email protected]>

* Update

* Update doc

* Update

* Update

* Fix follow_split and follow_fused_split record test.

Signed-off-by: jingbang.yjb <[email protected]>

* Doc update

* Update some doc strings

Signed-off-by: jingbang.yjb <[email protected]>

* Fix code style and some function definitions.

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Add comments on parameters.

Signed-off-by: jingbang.yjb <[email protected]>

* Add more doc strings and fix some.

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update.

Signed-off-by: jingbang.yjb <[email protected]>

Co-authored-by: chengfan.jcf <[email protected]>
Co-authored-by: jingbang.yjb <[email protected]>
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Sep 3, 2020
…t steps (apache#6142)

* Add cache_read/cache_write step

* Update

* Add follow split and follow fused split

Signed-off-by: jingbang.yjb <[email protected]>

Conflicts:
	src/auto_scheduler/compute_dag.cc
	src/auto_scheduler/transform_step.cc
	src/auto_scheduler/transform_step.h
	tests/python/unittest/test_auto_scheduler_loop_state.py

* add loop_state.py

Signed-off-by: jingbang.yjb <[email protected]>

* Update

* Update

* Update state->current_compute_dag to Optional

* Add some doc strings for Follow_Split and Follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Check code using c-lint

Signed-off-by: jingbang.yjb <[email protected]>

* Add more doc strings and change the order for follow split.

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add record test for follow_fused_split.

Signed-off-by: jingbang.yjb <[email protected]>

* Add test record for follow_fused_split
1. delete a comment
2. add "fuse" between follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <[email protected]>

* Add doc strings for some functions and variables

Signed-off-by: jingbang.yjb <[email protected]>

* Fix the code format in src/auto_scheduler/transform_step.h

Signed-off-by: jingbang.yjb <[email protected]>

* Update

* Update doc

* Update

* Update

* Fix follow_split and follow_fused_split record test.

Signed-off-by: jingbang.yjb <[email protected]>

* Doc update

* Update some doc strings

Signed-off-by: jingbang.yjb <[email protected]>

* Fix code style and some function definitions.

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Add comments on parameters.

Signed-off-by: jingbang.yjb <[email protected]>

* Add more doc strings and fix some.

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update

Signed-off-by: jingbang.yjb <[email protected]>

* Update.

Signed-off-by: jingbang.yjb <[email protected]>

Co-authored-by: chengfan.jcf <[email protected]>
Co-authored-by: jingbang.yjb <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants