Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow split #1

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions python/tvm/auto_scheduler/loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ def stages(self):
"""
return self.state_object.stages

@property
def transform_steps(self):
"""
Returns
-------
transform_steps : List[transform_steps]
"""
return self.state_object.transform_steps

@property
def stage_ops(self):
"""
Expand Down Expand Up @@ -294,6 +303,89 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
iterator, lengths, inner_to_outer)
return res

def follow_split(self, stage, iterator, src_step_id, n_split):
""" Schedule primitive extends to split step.

This step is used to follow a former SplitStep, keeps their iterator structures to be same.

Example cases:
With subgraph: Dense -> Relu
Some tiling structures are used in Relu stage and we intend to compute the Dense
stage at Relu.
The follow_split is used here to keep their outer most few iterators the same for
applying compute at.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_id : int
The index of the split step to follow in the history.
n_split : int
The number of split level.

Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""

self.state_object, res = _ffi_api.StateFollowSplit(self.state_object,
self._resolve_stage_id(stage),
iterator,
src_step_id, n_split)
return res

def follow_fused_split(self, stage, iterator, src_step_ids, level,
factor_or_nparts):
""" Schedule primitive extends to split step.

This step is used to follow several former SplitSteps and FuseSteps.

Example cases:
With subgraph in GPU schedule: Input -> Dense
for [email protected] = ... : Bind to blockIdx.x
for [email protected] = ... : Bind to threadIdx.x
for [email protected] = ...
Input_shared = Input ...
for k = ...
Dense = ...
We intend to apply cooperative fetching with the Input stage, while the threadIdx.x
axis is binded to a iterator generated by split & fuse step.
The follow_fused_step is used here to figure out the final extent of the threadIdx.x
binded iterator.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_ids : List[int]
The indices of the split steps to follow in the history.
level : int
Use the length in this split level.
factor_or_nparts : bool
True to use `factor` for split from inner to outer,
False to use `nparts` for split from outer to inner.

Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""

self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object,
self._resolve_stage_id(stage),
iterator,
src_step_ids, level,
factor_or_nparts)
return res

def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to te.compute_at.

Expand Down
4 changes: 2 additions & 2 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
// Apply the history steps to TVM schedule
// Call each step's ApplyToSchedule method
for (const auto& step : transform_steps) {
StepApplyToSchedule(step, stages, stage_to_axes, &schedule);
StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps);
}

return std::make_pair(schedule, operator->()->tensors);
Expand Down Expand Up @@ -298,7 +298,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
}
// Call each step's PrintAsPythonAPI method
for (const auto& step : transform_steps) {
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule);
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps);
}

return ss.str();
Expand Down
40 changes: 40 additions & 0 deletions src/auto_scheduler/loop_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,27 @@ Array<Iterator> State::split(int stage_id, const Iterator& it,
return step->ApplyToState(this);
}

Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
int n_split) {
const Stage& stage = operator->()->stages[stage_id];

FollowSplitStep step =
FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this);
}

Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
const Array<Integer>& src_step_ids, int level,
bool factor_or_nparts) {
const Stage& stage = operator->()->stages[stage_id];

FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it),
src_step_ids, level, factor_or_nparts);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this);
}

void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
const Stage& target_stage = operator->()->stages[target_stage_id];
ComputeAtStep step =
Expand Down Expand Up @@ -455,6 +476,25 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
return Array<ObjectRef>{state, res};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
.set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
int n_split) {
const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
return Array<ObjectRef>{state, Array<Iterator>(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
.set_body_typed([](State state, int stage_id, const Iterator& it,
const Array<IntImm>& src_step_ids, int level, bool factor_or_nparts) {
Array<Integer> array_src_step_ids;
for (const auto& i : src_step_ids) {
array_src_step_ids.push_back(i->value);
}
const auto& res =
state.follow_fused_split(stage_id, it, array_src_step_ids, level, factor_or_nparts);
return Array<ObjectRef>{state, Array<Iterator>(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
.set_body_typed([](State state, int stage_id, int target_stage_id,
const Iterator& target_iter) {
Expand Down
22 changes: 22 additions & 0 deletions src/auto_scheduler/loop_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,28 @@ class State : public ObjectRef {
*/
Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
bool inner_to_outer = true);
/*!
* \brief Schedule primitive corresponds to te.follow_split.
* \param stage_id The index of the stage to be split.
* \param it The iterator to be split.
* \param src_step_id The index of the split step to follow in the history.
* \param n_split The number of split level.
* \return The splitted new Iterators.
*/
Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split);
/*!
* \brief Schedule primitive corresponds to te.follow_split.
* \param stage_id The index of the stage to be split.
* \param it The iterator to be split.
* \param src_step_ids The indices of the split steps to follow in the history.
* \param level Use the length in this split level.
* \param factor_or_nparts True to use `factor` for split from inner to outer,
False to use `nparts` for split from outer to inner.
* \return The splitted new Iterators.
*/
Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
const Array<Integer>& src_step_ids, int level,
bool factor_or_nparts);

/********** Step APIs working on multiple stages **********/
jcf94 marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
Loading