Skip to content

Commit

Permalink
[TIR][Schedule] Add Annotate/Unannotate primitive
Browse files Browse the repository at this point in the history
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
  • Loading branch information
7 people committed Dec 14, 2021
1 parent ceec0fc commit aa68f04
Show file tree
Hide file tree
Showing 12 changed files with 529 additions and 6 deletions.
28 changes: 28 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,34 @@ class ScheduleNode : public runtime::Object {
int offset) = 0;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/*!
* \brief Annotate a loop with a key value pair
* \param loop The loop to be annotated
* \param ann_key The annotation key
* \param ann_val The annotation value, a string or a ExprRV
*/
virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0;
/*!
* \brief Annotate a block with a key value pair
* \param loop The block to be annotated
* \param ann_key The annotation key
* \param ann_val The annotation value, a string or a ExprRV
*/
virtual void Annotate(const BlockRV& block_rv, const String& ann_key,
const ObjectRef& ann_val) = 0;
/*!
* \brief Unannotate a loop's annotation with key ann_key
* \param loop The loop to be unannotated
* \param ann_key The annotation key
*/
virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0;
/*!
* \brief Unannotate a block's annotation with key ann_key
* \param loop The block to be unannotated
* \param ann_key The annotation key
*/
virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0;

/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import synr
import tvm.tir
from tvm.runtime import Object
from tvm.runtime import Object, String
from tvm.ir import Span, Range
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind

Expand Down Expand Up @@ -486,7 +486,7 @@ def create_loop_info(
self.annotations: Mapping[str, Object] = {}
if annotations is not None:
self.annotations = {
key: tvm.tir.StringImm(val) if isinstance(val, str) else val
key: String(val) if isinstance(val, str) else val
for key, val in annotations.items()
}

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tvm.ir.expr import PrimExpr, Range

import tvm.tir
from tvm.runtime import Object
from tvm.runtime import Object, String
from tvm import te
from tvm.target import Target
from tvm.ir import Span
Expand Down Expand Up @@ -430,7 +430,7 @@ def block_attr(attrs: Mapping[str, Object], span: Span = None):
span,
)
attrs = {
key: tvm.tir.StringImm(val) if isinstance(val, str) else val
key: String(val) if isinstance(val, str) else val
for key, val in attrs.items()
}
block_scope.annotations = attrs
Expand Down
121 changes: 119 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object
from tvm.tir import Block, For, IntImm, PrimFunc
from tvm.runtime import Object, String
from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc

from . import _ffi_api
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
Expand Down Expand Up @@ -1664,6 +1664,123 @@ def after_storage_align(a: T.handle, c: T.handle) -> None:

########## Schedule: Annotation ##########

def annotate(
self,
block_or_loop: Union[BlockRV, LoopRV],
ann_key: str,
ann_val: Union[str, int, float, ExprRV],
) -> None:
"""Annotate a block/loop with a key value pair
Parameters
----------
block_or_loop: Union[BlockRV, LoopRV]
The block/loop to be annotated
ann_key : str
The annotation key
ann_val : Union[str, int, float, ExprRV]
The annotation value
Examples
--------
Before annotate, in TensorIR, the IR is:
.. code-block:: python
@T.prim_func
def before_annotate(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do annotate:
.. code-block:: python
sch = tir.Schedule(before_annotate)
sch.annotate(sch.get_block("B"), "ann_key", "ann_value")
print(sch.mod["main"].script())
After applying annotate, the IR becomes:
.. code-block:: python
@T.prim_func
def after_annotate(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
T.block_attr({"ann_key", "ann_value"})
B[vi, vj] = A[vi, vj] * 2.0
"""
if isinstance(ann_val, str):
ann_val = String(ann_val)
elif isinstance(ann_val, int):
ann_val = IntImm("int32", ann_val)
elif isinstance(ann_val, float):
ann_val = FloatImm("float32", ann_val)
_ffi_api.ScheduleAnnotate( # pylint: disable=no-member
self, block_or_loop, ann_key, ann_val
)

def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> None:
"""Unannotate a block/loop's annotation with key ann_key
Parameters
----------
block_or_loop: Union[BlockRV, LoopRV]
The block/loop to be unannotated
ann_key : str
The annotation key
Examples
--------
Before unannotate, in TensorIR, the IR is:
.. code-block:: python
@T.prim_func
def before_unannotate(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
T.block_attr({"ann_key", "ann_value"})
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do annotate:
.. code-block:: python
sch = tir.Schedule(before_unannotate)
sch.unannotate(sch.get_block("B"), "ann_key")
print(sch.mod["main"].script())
After applying unannotate, the IR becomes:
.. code-block:: python
@T.prim_func
def after_unannotate(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
"""
_ffi_api.ScheduleUnannotate(self, block_or_loop, ann_key) # pylint: disable=no-member

########## Schedule: Misc ##########

@type_checked
Expand Down
47 changes: 47 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,53 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {

/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/

ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_val) {
if (ann_val.as<StringObj>()) {
return ann_val;
}
if (const auto* expr = ann_val.as<PrimExprNode>()) {
ICHECK(!ann_val->IsInstance<StringImmNode>())
<< "TypeError: runtime::String is expected, but gets StringImm";
return this->Get(GetRef<PrimExpr>(expr));
}
LOG(FATAL)
<< "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but "
<< "gets: " << ann_val->GetTypeKey();
throw;
}

void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key,
const ObjectRef& ann_val) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->CheckAndGetAnnotationValue(ann_val));
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_);
}

void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_);
}

void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key,
const ObjectRef& ann_val) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Annotate(state_, this->GetSRef(block_rv), ann_key,
this->CheckAndGetAnnotationValue(ann_val));
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_);
}

void ConcreteScheduleNode::Unannotate(const BlockRV& loop_rv, const String& ann_key) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_);
}

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
12 changes: 12 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ class ConcreteScheduleNode : public ScheduleNode {
int offset) override;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
void Annotate(const BlockRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
void Unannotate(const BlockRV& loop_rv, const String& ann_key) override;

/******** Schedule: Misc ********/
void EnterPostproc() override {}

Expand Down Expand Up @@ -161,6 +166,13 @@ class ConcreteScheduleNode : public ScheduleNode {
inline Array<ExprRV> CreateRV(const std::vector<int64_t>& value);
/*! \brief Remove a random variable from the symbol table */
inline void RemoveFromSymbolTable(const ObjectRef& rv);
/*!
* \brief Check the annotation value is valid and look up the random variable. Raises an exception
* if the type of the annotation value is not allowed.
* \param The annotation value.
* \return The annotation value with random variables substituted with their values.
*/
ObjectRef CheckAndGetAnnotationValue(const ObjectRef& ann_val);
};

// implementations
Expand Down
17 changes: 17 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,23 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu

/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/*!
* \brief Annotate a block/loop with a key value pair
* \param self The state of the schedule
* \param sref The block/loop sref to be annotated
* \param ann_key The annotation key
* \param ann_val The annotation value
*/
TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key,
const ObjectRef& ann_val);
/*!
* \brief Unannotate a block/loop's annotation with key ann_key
* \param self The state of the schedule
* \param sref The block/loop to be unannotated
* \param ann_key The annotation key
*/
TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key);

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
Loading

0 comments on commit aa68f04

Please sign in to comment.