Skip to content

Commit

Permalink
[TIR] Add schedule primitive TransformBlockLayout (#11485)
Browse files Browse the repository at this point in the history
* [TIR] Add schedule primitive TransformBlockLayout

* fixup! [TIR] Add schedule primitive TransformBlockLayout

Fix doc
  • Loading branch information
vinx13 authored May 29, 2022
1 parent dd2897c commit d4a3968
Show file tree
Hide file tree
Showing 15 changed files with 635 additions and 26 deletions.
10 changes: 10 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,16 @@ class ScheduleNode : public runtime::Object {
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;

/*!
* \brief Apply a transformation represented by IndexMap to block
* \details The block iters and the block body are transformed by the given index_map.
* Outer loops corresponding to each new block iter are regenerated.
* The index_map is required to be bijective affine since we need its inverse mapping.
* \param block_rv The block to be transformed
* \param index_map The transformation to apply.
*/
virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) = 0;

/*!
* \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
* or write index
Expand Down
61 changes: 61 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2286,6 +2286,67 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
self, block, buffer_index, buffer_index_type_enum, axis_separators
)

@type_checked
def transform_block_layout(
self,
block: BlockRV,
index_map: Union[IndexMap, Callable],
) -> None:
"""Apply a transformation represented by IndexMap to block
Parameters
----------
block : BlockRV
The block to be transformed
index_map : Union[IndexMap, Callable]
The transformation to apply.
Examples
--------
Before transform_block_layout, in TensorIR, the IR is:
.. code-block:: python
@T.prim_func
def before_transform_block_layout(
A: T.Buffer[(16, 16), "float32"],
B: T.Buffer[(16, 16), "float32"]
) -> None:
for i, j in T.grid(16, 16):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do transform_block_layout:
.. code-block:: python
sch = tir.Schedule(before_transform_block_layout)
sch.transform_block_layout(sch.get_block("B"), lambda i, j: (i * 16 + j,))
print(sch.mod["main"].script())
After applying transform_block_layout, the IR becomes:
.. code-block:: python
@T.prim_func
def after_transform_block_layout(
A: T.Buffer[(16, 16), "float32"],
B: T.Buffer[(16, 16), "float32"]
) -> None:
for i in range(256):
with T.block("B"):
vi, = T.axis.remap("S", [i])
B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0
"""
if callable(index_map):
index_map = IndexMap.from_func(index_map)
_ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint: disable=no-member
self, block, index_map
)

@type_checked
def set_axis_separator(
self,
Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,17 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
std::unordered_set<const VarNode*>* data_par_vars,
std::unordered_set<const VarNode*>* reduce_vars);

/******** Loop properties ********/
/*!
* \brief Check the loop starts with zero.
* \param self The schedule state
* \param loop_sref The StmtSRef that points to the loop to be checked
* \param analyzer The arithmetic analyzer
* \throw ScheduleError If the loop doesn't starts with zero.
*/
void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref,
arith::Analyzer* analyzer);

/******** Block-loop relation ********/

/*!
Expand Down
29 changes: 29 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,35 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
return has_block_vars_of_other_types;
}

/******** Loop properties ********/

void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref,
arith::Analyzer* analyzer) {
class LoopNotStartWithZeroError : public ScheduleError {
public:
explicit LoopNotStartWithZeroError(IRModule mod, For loop)
: mod_(mod), loop_(std::move(loop)) {}

String FastErrorString() const final {
return "ScheduleError: The primitive only supports loop starting with 0";
}

String DetailRenderTemplate() const final {
return "The loop {0} does not start with 0, which is not supported";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

IRModule mod_;
For loop_;
};
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
if (!analyzer->CanProve(loop->min == 0)) {
throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
}
}

/******** Block-loop relation ********/

Array<StmtSRef> GetChildBlockSRefOnSRefTree(const ScheduleState& self,
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,14 @@ void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_i
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
}

void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv,
const IndexMap& index_map) {
TVM_TIR_SCHEDULE_BEGIN();
tir::TransformBlockLayout(state_, this->GetSRef(block_rv), index_map);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("transform_block_layout", this->error_render_level_);
}

void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Schedule: Layout transformation ********/
void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
const IndexMap& index_map) override;
void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) override;
Expand Down
12 changes: 12 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,18 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String&
TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map);

/*!
* \brief Apply a transformation represented by IndexMap to block
* \details The block iters and the block body are transformed by the given index_map.
* Outer loops corresponding to each new block iter are regenerated.
* The index_map is required to be bijective affine since we need its inverse mapping.
* \param self The state of the schedule
* \param block_sref The block sref that refers to the block to be transformed
* \param index_map The transformation to apply.
*/
TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
const IndexMap& index_map);

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
Loading

0 comments on commit d4a3968

Please sign in to comment.