Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow split #1

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions python/tvm/auto_scheduler/loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ def stages(self):
"""
return self.state_object.stages

@property
def transform_steps(self):
"""
Returns
-------
transform_steps : List[transform_steps]
"""
return self.state_object.transform_steps

@property
def stage_ops(self):
"""
Expand Down Expand Up @@ -294,6 +303,57 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
iterator, lengths, inner_to_outer)
return res

def follow_split(self, stage, iterator, src_step_id, n_split):
"""
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
iterator : Iterator
The iterator to split
src_step_id : int
The index of the split step to follow in the history
n_split : int
The number of split level

Returns
-------
res_its : List[Iterator]
The splitted new Iterators
"""

self.state_object, res = _ffi_api.StateFollowSplit(self.state_object,
self._resolve_stage_id(stage),
iterator,
src_step_id, n_split)
return res

def follow_fused_split(self, stage, iterator, src_step_ids, level,
factor_or_nparts):
"""
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
iterator : Iterator
The iterator to split
src_step_ids : List[int]
The indices of the split steps to follow in the history
level : int
Use the length in this split level
factor_or_nparts : bool
True to use `factor` for split from inner to outer,
False to use `nparts` for split from outer to inner

Returns
-------
res_its : List[Iterator]
The splitted new Iterators
"""

self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object,
self._resolve_stage_id(stage),
iterator,
src_step_ids, level,
factor_or_nparts)
return res

def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to te.compute_at.

Expand Down
4 changes: 2 additions & 2 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
// Apply the history steps to TVM schedule
// Call each step's ApplyToSchedule method
for (const auto& step : transform_steps) {
StepApplyToSchedule(step, stages, stage_to_axes, &schedule);
StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps);
}

return std::make_pair(schedule, operator->()->tensors);
Expand Down Expand Up @@ -298,7 +298,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
}
// Call each step's PrintAsPythonAPI method
for (const auto& step : transform_steps) {
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule);
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps);
}

return ss.str();
Expand Down
42 changes: 42 additions & 0 deletions src/auto_scheduler/loop_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,28 @@ Iterator State::vectorize(int stage_id, const Iterator& it) {
return step->ApplyToState(this);
}

Array<Iterator> State::follow_split(int stage_id, const Iterator& it,
int src_step_id, int n_split) {
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
const Stage& stage = operator->()->stages[stage_id];

FollowSplitStep step = FollowSplitStep(
stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this);
}

Array<Iterator> State::follow_fused_split(
int stage_id, const Iterator& it, const Array<Integer>& src_step_ids,
int level, bool factor_or_nparts) {
const Stage& stage = operator->()->stages[stage_id];

FollowFusedSplitStep step =
FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it),
src_step_ids, level, factor_or_nparts);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this);
}

Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
const Stage& stage = operator->()->stages[stage_id];
Array<Integer> indices;
Expand Down Expand Up @@ -436,6 +458,26 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize")
return Array<ObjectRef>{state, res};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
.set_body_typed([](State state, int stage_id, const Iterator& it,
int src_step_id, int n_split) {
const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
return Array<ObjectRef>{state, Array<Iterator>(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
.set_body_typed([](State state, int stage_id, const Iterator& it,
const Array<IntImm>& src_step_ids, int level,
bool factor_or_nparts) {
Array<Integer> array_src_step_ids;
for (const auto& i : src_step_ids) {
array_src_step_ids.push_back(i->value);
}
const auto& res = state.follow_fused_split(
stage_id, it, array_src_step_ids, level, factor_or_nparts);
return Array<ObjectRef>{state, Array<Iterator>(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse")
.set_body_typed([](State state, int stage_id, const Array<Iterator>& iters) {
const auto& res = state.fuse(stage_id, iters);
Expand Down
4 changes: 3 additions & 1 deletion src/auto_scheduler/loop_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,10 @@ class State : public ObjectRef {
*/
Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
bool inner_to_outer = true);

jcf94 marked this conversation as resolved.
Show resolved Hide resolved
Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split);

/********** Step APIs working on multiple stages **********/
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
Array<Iterator> follow_fused_split(int stage_id, const Iterator& it, const Array<Integer>& src_step_ids, int level, bool factor_or_nparts);

/*!
* \brief Schedule primitive corresponds to te.compute_at.
Expand Down
Loading