Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps #6142

Merged
merged 35 commits into from
Jul 28, 2020
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
d12465d
Add cache_read/cache_write step
jcf94 Jul 21, 2020
920f4b1
Update
jcf94 Jul 22, 2020
90e6391
Add follow split and follow fused split
Jul 22, 2020
e144082
add loop_state.py
Jul 22, 2020
86c3670
Update
jcf94 Jul 22, 2020
abfb150
Update
jcf94 Jul 23, 2020
3c1da64
Update state->current_compute_dag to Optional
jcf94 Jul 23, 2020
c4a344c
Add some doc strings for Follow_Split and Follow_fused_split
Jul 23, 2020
d3969b8
Check code using c-lint
Jul 23, 2020
f209525
Add more doc strings and change the order for follow split.
Jul 23, 2020
50f7c4a
Add record test for follow_split and follow_fused_split
Jul 23, 2020
7bf8dd5
Add record test for follow_split
Jul 23, 2020
98d943b
Add record test for follow_fused_split.
Jul 24, 2020
296cb36
Add test record for follow_fused_split
Jul 24, 2020
a7b1294
Add doc strings for some functions and variables
Jul 24, 2020
5220a68
Fix the code format in src/auto_scheduler/transform_step.h
Jul 24, 2020
2a113d3
Update
jcf94 Jul 24, 2020
bf660a8
Update doc
jcf94 Jul 24, 2020
3649e26
Update
jcf94 Jul 24, 2020
85da7e0
Update
jcf94 Jul 24, 2020
1a87244
Fix follow_split and follow_fused_split record test.
Jul 24, 2020
4b09317
Merge branch 'follow_split' into merge_follow_split_cache
Jul 24, 2020
87e703a
Doc update
jcf94 Jul 25, 2020
3eeaff0
Merge branch 'cache_read_write' into merge_follow_split_cache
Jul 27, 2020
704c2cc
Update some doc strings
Jul 27, 2020
e5764a6
Add follow split step and follow fused split step.
Jul 27, 2020
40f9638
Fix code style and some function definitions.
Jul 27, 2020
22b1f3b
Update
Jul 27, 2020
d2e3da6
Add comments on parameters.
Jul 27, 2020
072c868
Add more doc strings and fix some.
Jul 28, 2020
be75123
Update
Jul 28, 2020
f4771f9
Update
Jul 28, 2020
ad94db3
Fix some function parameters description.
Jul 28, 2020
0780757
Update
Jul 28, 2020
4d3a426
Update.
Jul 28, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions include/tvm/auto_scheduler/loop_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,29 @@ class State : public ObjectRef {
TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
const Array<Optional<Integer>>& lengths,
bool inner_to_outer = true);
/*!
* \brief Schedule primitive extends to split step.
* \param stage_id The index of the stage to be split.
* \param it The iterator to be split.
* \param src_step_id The index of the split step to follow in the history.
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved
* \param n_split The number of split level.
* \return The splitted new Iterators.
*/
TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
int n_split);
/*!
* \brief Schedule primitive extends to split step.
* \param stage_id The index of the stage to be split.
* \param it The iterator to be split.
* \param src_step_ids The indices of the split steps to follow in the history.
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved
* \param level Use the length in this split level.
* \param factor_or_nparts True to use `factor` for split from inner to outer,
False to use `nparts` for split from outer to inner.
* \return The splitted new Iterators.
*/
TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
const Array<Integer>& src_step_ids, int level,
bool factor_or_nparts);

/********** Step APIs working on multiple stages **********/

Expand Down
167 changes: 165 additions & 2 deletions include/tvm/auto_scheduler/transform_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need
* `te::Schedule` API. (e.g. CacheRead/CacheWrite step)
* \param transform_steps An array record all transform steps.
*/
void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
te::Schedule* schedule);
te::Schedule* schedule, const Array<Step>& transform_steps);

/*!
* \brief Print the step as equivalent python schedule API.
Expand All @@ -213,10 +214,12 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g.
* CacheRead/CacheWrite step)
* \param transform_steps An array record all transform steps.
* \return Python schedule code.
*/
String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes, te::Schedule* schedule);
StageToAxesMap* stage_to_axes, te::Schedule* schedule,
const Array<Step>& transform_steps);

/********** Steps working on single stage **********/

Expand Down Expand Up @@ -489,6 +492,166 @@ class SplitStep : public Step {

/********** Steps working on multiple stages **********/

/*! \brief Similar to SplitStepNode, but uses split factors from another step
* (i.e. Follow another split step) */
class FollowSplitStepNode : public StepNode {
public:
/*! \brief The id of the iter to be split. */
int iter_id;
/*! \brief The index of the split step to follow in the history. */
int src_step_id;
/*! \brief The number of split level. */
int n_split;

void WriteToRecord(dmlc::JSONWriter* writer) const final;

/*!
* \brief Extract split lengths.
* \param transform_steps An array record all transform steps.
* \param lengths The multiple split factors. Can be None to be filled by search policy.
*/
void ExtractSplitLengths(const Array<Step>& transform_steps,
Array<Optional<Integer>>* lengths) const;

/*!
* \brief Apply the current step to State.
* \param state A mutable pointer to State.
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved
*/
Array<Iterator> ApplyToState(State* state) const;

/*!
* \brief Apply the current step to tvm.schedule.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
* \param transform_steps An array record all transform steps.
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved
*/
Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const;

/*!
* \brief Print the current step as equivalent python schedule API.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved
* \param transform_steps An array record all transform steps.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const;

static constexpr const char* record_prefix_str = "FSP";

static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
};

/*!
* \brief Managed reference to FollowSplitStepNode.
* \sa FollowSplitStepNode
*/
class FollowSplitStep : public Step {
public:
/*!
* \brief The constructor.
* \param stage_id The index of the stage to be split.
* \param iter_id The index of the iterator to be split.
* \param src_step_id The index of the split step to follow in the history.
* \param n_split The number of split level.
*/
FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);

/*!
* \brief The constructor used to read a step record from JSONReader and create the
* corresponding step.
* \param reader The input JSONReader.
*/
explicit FollowSplitStep(dmlc::JSONReader* reader);

TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
};

/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps.
* \note This can be used for the split in cooperative fetching
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved
*/
class FollowFusedSplitStepNode : public StepNode {
public:
/*! \brief The id of the iter to split. */
int iter_id;
/*! \brief The indices of the split steps to follow in the history. */
Array<Integer> src_step_ids;
/*! \brief Use the length in this split level. */
int level;
/*! \brief If this is true, use factor. Otherwise, use nparts. */
bool factor_or_nparts;

void WriteToRecord(dmlc::JSONWriter* writer) const final;

/*!
* \brief Extract split length.
* \param transform_steps An array record all transform steps.
* \return Split factor.
*/
Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;

/*!
* \brief Apply the current step to State.
* \param state A mutable pointer to State.
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved
*/
Array<Iterator> ApplyToState(State* state) const;

/*!
* \brief Apply the current step to tvm.schedule.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
* \param transform_steps An array record all transform steps.
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved
*/
Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const;

/*!
* \brief Print the current step as equivalent python schedule API.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved
* \param transform_steps An array record all transform steps.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const;

static constexpr const char* record_prefix_str = "FFSP";

static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
};

/*!
* \brief Managed reference to FollowFusedSplitStepNode.
* \sa FollowFusedSplitStepNode
*/
class FollowFusedSplitStep : public Step {
public:
/*!
* \brief The constructor.
* \param stage_id The index of the stage to be split.
* \param iter_id The index of the iterator to be split.
* \param src_step_ids An array of index for split step to follow in the history.
* \param level Use the length in this split level.
* \param factor_or_nparts If this is true, use factor. Otherwise, use nparts.
*/
FollowFusedSplitStep(int stage_id, int iter_id, const Array<Integer>& src_step_ids, int level,
bool factor_or_nparts);

/*!
* \brief The constructor used to read a step record from JSONReader and create the
* corresponding step.
* \param reader The input JSONReader.
*/
explicit FollowFusedSplitStep(dmlc::JSONReader* reader);

TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
};

/********** Steps working on multiple stages **********/

/*! \brief Compute at step that corresponds to te::Stage::compute_at */
class ComputeAtStepNode : public StepNode {
public:
Expand Down
92 changes: 92 additions & 0 deletions python/tvm/auto_scheduler/loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ def stages(self):
"""
return self.state_object.stages

@property
def transform_steps(self):
"""
Returns
-------
transform_steps : List[transform_steps]
"""
return self.state_object.transform_steps

@property
def stage_ops(self):
"""
Expand Down Expand Up @@ -301,6 +310,89 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
iterator, lengths, inner_to_outer)
return res

def follow_split(self, stage, iterator, src_step_id, n_split):
""" Schedule primitive extends to split step.

This step is used to follow a former SplitStep, keeps their iterator structures to be same.
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved

Example cases:
With subgraph: Dense -> Relu
Some tiling structures are used in Relu stage and we intend to compute the Dense
stage at Relu.
The follow_split is used here to keep their outer most few iterators the same for
applying compute at.
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_id : int
The index of the split step to follow in the history.
n_split : int
The number of split level.

Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""

self.state_object, res = _ffi_api.StateFollowSplit(self.state_object,
self._resolve_stage_id(stage),
iterator,
src_step_id, n_split)
return res

def follow_fused_split(self, stage, iterator, src_step_ids, level,
factor_or_nparts):
""" Schedule primitive extends to split step.

This step is used to follow several former SplitSteps and FuseSteps.

Example cases:
With subgraph in GPU schedule: Input -> Dense
for [email protected] = ... : Bind to blockIdx.x
for [email protected] = ... : Bind to threadIdx.x
for [email protected] = ...
Input_shared = Input ...
for k = ...
Dense = ...
We intend to apply cooperative fetching with the Input stage, while the threadIdx.x
axis is binded to a iterator generated by split & fuse step.
The follow_fused_step is used here to figure out the final extent of the threadIdx.x
binded iterator.
jiuqi-yang marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_ids : List[int]
The indices of the split steps to follow in the history.
level : int
Use the length in this split level.
factor_or_nparts : bool
True to use `factor` for split from inner to outer,
False to use `nparts` for split from outer to inner.

Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""

self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object,
self._resolve_stage_id(stage),
iterator,
src_step_ids, level,
factor_or_nparts)
return res

def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
more details.
Expand Down
4 changes: 2 additions & 2 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
// Apply the history steps to TVM schedule
// Call each step's ApplyToSchedule method
for (const auto& step : transform_steps) {
StepApplyToSchedule(step, stages, stage_to_axes, &schedule);
StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps);
}

return std::make_pair(schedule, operator->()->tensors);
Expand Down Expand Up @@ -722,7 +722,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
}
// Call each step's PrintAsPythonAPI method
for (const auto& step : transform_steps) {
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule);
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps);
}

return ss.str();
Expand Down
34 changes: 34 additions & 0 deletions src/auto_scheduler/loop_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,25 @@ Array<Iterator> State::split(int stage_id, const Iterator& it,
return step->ApplyToState(this);
}

Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
int n_split) {
const Stage& stage = operator->()->stages[stage_id];
FollowSplitStep step =
FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this);
}

Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
const Array<Integer>& src_step_ids, int level,
bool factor_or_nparts) {
const Stage& stage = operator->()->stages[stage_id];
FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it),
src_step_ids, level, factor_or_nparts);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this);
}

void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
const Stage& target_stage = operator->()->stages[target_stage_id];
ComputeAtStep step =
Expand Down Expand Up @@ -454,6 +473,21 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
return Array<ObjectRef>{state, res};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
.set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
int n_split) {
const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
return Array<ObjectRef>{state, Array<Iterator>(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
.set_body_typed([](State state, int stage_id, const Iterator& it,
const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) {
const auto& res =
state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts);
return Array<ObjectRef>{state, Array<Iterator>(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
.set_body_typed([](State state, int stage_id, int target_stage_id,
const Iterator& target_iter) {
Expand Down
Loading