Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ansor][AutoTVM v2.0] Phase 1: Add annotation/compute_at/compute_root/compute_inline steps #6073

Merged
merged 15 commits into from
Jul 21, 2020
163 changes: 152 additions & 11 deletions python/tvm/auto_scheduler/loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,51 @@ def reorder(self, stage, order):
order : List[Iterator]
Iterators in the expected order.
"""
stage_id = self._resolve_stage_id(stage)
self.state_object = _ffi_api.StateReorder(self.state_object, self._resolve_stage_id(stage),
order)

self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
def compute_at(self, stage, target_stage, target_iter):
    """ Schedule primitive corresponds to te.compute_at.

    Parameters
    ----------
    stage : Union[int, Operation, Tensor]
        The Stage to be compute at, can be a Stage order index, Stage operation or stage
        output tensor.
    target_stage : Union[int, Operation, Tensor]
        The target stage of compute_at, can be a Stage order index, Stage operation or stage
        output tensor.
    target_iter : Iterator
        The target Iterator of compute_at.
    """
    # Resolve both stages to their order indices before crossing the FFI boundary.
    src_stage_id = self._resolve_stage_id(stage)
    dst_stage_id = self._resolve_stage_id(target_stage)
    self.state_object = _ffi_api.StateComputeAt(self.state_object, src_stage_id,
                                                dst_stage_id, target_iter)

def compute_root(self, stage):
    """ Schedule primitive corresponds to te.compute_root.

    Parameters
    ----------
    stage : Union[int, Operation, Tensor]
        The Stage to be compute root, can be a Stage order index, Stage operation or stage
        output tensor.
    """
    stage_id = self._resolve_stage_id(stage)
    self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id)

def compute_inline(self, stage):
    """ Schedule primitive corresponds to te.compute_inline.

    Parameters
    ----------
    stage : Union[int, Operation, Tensor]
        The Stage to be compute inline, can be a Stage order index, Stage operation or stage
        output tensor.
    """
    stage_id = self._resolve_stage_id(stage)
    self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id)

def split(self, stage, iterator, lengths, inner_to_outer=True):
""" Schedule primitive corresponds to te.split.
Expand All @@ -144,12 +186,11 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
Returns
-------
res_its : List[Iterator]
The splitted new Iterators
The splitted new Iterators.
"""
stage_id = self._resolve_stage_id(stage)

self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths,
inner_to_outer)
self.state_object, res = _ffi_api.StateSplit(self.state_object,
self._resolve_stage_id(stage),
iterator, lengths, inner_to_outer)
return res

def fuse(self, stage, iters):
Expand All @@ -161,16 +202,116 @@ def fuse(self, stage, iters):
The Stage to be fused, can be a Stage order index, Stage operation or stage
output tensor.
iters : List[Iterator]
The iterators to be fused
The iterators to be fused.

Returns
-------
res_it : Iterator
The fused Iterator.
"""
self.state_object, res = _ffi_api.StateFuse(self.state_object,
self._resolve_stage_id(stage), iters)
return res

def vectorize(self, stage, iterator):
    """ Schedule primitive corresponds to te.vectorize.

    Parameters
    ----------
    stage : Union[int, Operation, Tensor]
        The Stage to be vectorized, can be a Stage order index, Stage operation or stage
        output tensor.
    iterator : Iterator
        The iterator to be vectorized.

    Returns
    -------
    res_it : Iterator
        The vectorized Iterator.
    """
    # The stage id is resolved directly in the FFI call; the previous dead
    # `stage_id = ...` assignment (its result was never used) is removed.
    self.state_object, res = _ffi_api.StateVectorize(self.state_object,
                                                     self._resolve_stage_id(stage),
                                                     iterator)
    return res

def parallel(self, stage, iterator):
    """ Schedule primitive corresponds to te.parallel.

    Parameters
    ----------
    stage : Union[int, Operation, Tensor]
        The Stage to be paralleled, can be a Stage order index, Stage operation or stage
        output tensor.
    iterator : Iterator
        The iterator to be paralleled.

    Returns
    -------
    res_it : Iterator
        The paralleled Iterator.
    """
    # A stray leftover line calling _ffi_api.StateFuse inside this docstring
    # has been removed; this method only performs the parallel annotation.
    self.state_object, res = _ffi_api.StateParallel(self.state_object,
                                                    self._resolve_stage_id(stage),
                                                    iterator)
    return res

def unroll(self, stage, iterator, max_unroll=None):
""" Schedule primitive corresponds to te.unroll.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be unrolled, can be a Stage order index, Stage operation or stage
output tensor.
iterator : Iterator
The iterator to be unrolled.
max_unroll : Optional[int]
The max unroll limit. Iterator with extent larger than this limit will be skipped.

Returns
-------
res_it : Iterator
The unrolled Iterator.
"""
self.state_object, res = _ffi_api.StateUnroll(self.state_object,
self._resolve_stage_id(stage), iterator,
max_unroll if max_unroll else -1)
return res

def bind(self, stage, iterator, thread_name):
    """ Schedule primitive corresponds to te.bind.

    Parameters
    ----------
    stage : Union[int, Operation, Tensor]
        The Stage to be binded, can be a Stage order index, Stage operation or stage
        output tensor.
    iterator : Iterator
        The iterator to be binded.
    thread_name : str
        The thread type to be binded. Currently support:
        - vthread
        - blockIdx.x
        - threadIdx.x
        - blockIdx.y
        - threadIdx.y

    Returns
    -------
    res_it : Iterator
        The binded Iterator.
    """
    # Mapping from thread name to the integer annotation value consumed by the
    # C++ side — presumably the IteratorAnnotation enum; verify the values stay
    # in sync with the C++ definition.
    trans_table = {
        "vthread": 4,
        "blockIdx.x": 5,
        "threadIdx.x": 6,
        "blockIdx.y": 7,
        "threadIdx.y": 8,
    }
    if thread_name not in trans_table:
        # Format the name into the message; passing it as a second positional
        # argument would make the error render as a tuple.
        raise ValueError("Invalid thread_name: %s" % thread_name)

    self.state_object, res = _ffi_api.StateBind(self.state_object,
                                                self._resolve_stage_id(stage), iterator,
                                                trans_table[thread_name])
    return res

def copy(self):
Expand Down
16 changes: 16 additions & 0 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,18 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
// return value, so the ApplyToSchedule is not able to be merged to single interface
if (auto ps = step.as<ReorderStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeAtStepNode>()) {
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeRootStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeInlineStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<SplitStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<AnnotationStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else {
LOG(FATAL) << "Invalid Step";
}
Expand Down Expand Up @@ -328,10 +336,18 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
for (const auto& step : transform_steps) {
if (auto ps = step.as<ReorderStepNode>()) {
ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
} else if (auto ps = step.as<ComputeAtStepNode>()) {
ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
} else if (auto ps = step.as<ComputeRootStepNode>()) {
ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
} else if (auto ps = step.as<ComputeInlineStepNode>()) {
ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
} else if (auto ps = step.as<SplitStepNode>()) {
ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
} else if (auto ps = step.as<AnnotationStepNode>()) {
ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
} else {
LOG(FATAL) << "Invalid Step";
}
Expand Down
Loading