From aa68f04d8743da7a3a603347f47c163198d423a7 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 14 Dec 2021 14:49:58 -0500 Subject: [PATCH] [TIR][Schedule] Add Annotate/Unannotate primitive Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Wuwei Lin Co-authored-by: Xiyou Zhou --- include/tvm/tir/schedule/schedule.h | 28 +++ python/tvm/script/tir/scope_handler.py | 4 +- python/tvm/script/tir/special_stmt.py | 4 +- python/tvm/tir/schedule/schedule.py | 121 ++++++++++++- src/tir/schedule/concrete_schedule.cc | 47 +++++ src/tir/schedule/concrete_schedule.h | 12 ++ src/tir/schedule/primitive.h | 17 ++ src/tir/schedule/primitive/annotate.cc | 168 ++++++++++++++++++ src/tir/schedule/schedule.cc | 26 +++ src/tir/schedule/traced_schedule.cc | 38 ++++ src/tir/schedule/traced_schedule.h | 4 + .../unittest/test_tir_schedule_utilities.py | 66 +++++++ 12 files changed, 529 insertions(+), 6 deletions(-) create mode 100644 src/tir/schedule/primitive/annotate.cc diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index ffd860d84cf31..8646bf361f019 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -450,6 +450,34 @@ class ScheduleNode : public runtime::Object { int offset) = 0; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ + /*! + * \brief Annotate a loop with a key value pair + * \param loop The loop to be annotated + * \param ann_key The annotation key + * \param ann_val The annotation value, a string or a ExprRV + */ + virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0; + /*! + * \brief Annotate a block with a key value pair + * \param loop The block to be annotated + * \param ann_key The annotation key + * \param ann_val The annotation value, a string or a ExprRV + */ + virtual void Annotate(const BlockRV& block_rv, const String& ann_key, + const ObjectRef& ann_val) = 0; + /*! + * \brief Unannotate a loop's annotation with key ann_key + * \param loop The loop to be unannotated + * \param ann_key The annotation key + */ + virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0; + /*! + * \brief Unannotate a block's annotation with key ann_key + * \param loop The block to be unannotated + * \param ann_key The annotation key + */ + virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0; + /******** Schedule: Misc ********/ /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ virtual void EnterPostproc() = 0; diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 42f84bc40f60f..db3261e7a3920 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -20,7 +20,7 @@ import synr import tvm.tir -from tvm.runtime import Object +from tvm.runtime import Object, String from tvm.ir import Span, Range from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind @@ -486,7 +486,7 @@ def create_loop_info( self.annotations: Mapping[str, Object] = {} if annotations is not None: self.annotations = { - key: tvm.tir.StringImm(val) if isinstance(val, str) else val + key: String(val) if isinstance(val, str) else val for key, val in annotations.items() } diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index d7b09341ab9b4..cc5a9ab85e92e 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -24,7 +24,7 @@ from tvm.ir.expr import PrimExpr, Range import tvm.tir -from tvm.runtime import Object +from tvm.runtime import Object, String from tvm import te from tvm.target import Target from tvm.ir import Span @@ -430,7 +430,7 @@ def block_attr(attrs: Mapping[str, Object], span: Span = None): span, ) attrs = { - key: tvm.tir.StringImm(val) if isinstance(val, str) else val + key: String(val) if isinstance(val, str) else val for key, val in attrs.items() } block_scope.annotations = attrs diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 3a1d714956728..334f0dfb4d29c 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -20,8 +20,8 @@ from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr -from tvm.runtime import Object -from tvm.tir import Block, For, IntImm, PrimFunc +from tvm.runtime import Object, String +from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc from . import _ffi_api from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod @@ -1664,6 +1664,123 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: ########## Schedule: Annotation ########## + def annotate( + self, + block_or_loop: Union[BlockRV, LoopRV], + ann_key: str, + ann_val: Union[str, int, float, ExprRV], + ) -> None: + """Annotate a block/loop with a key value pair + + Parameters + ---------- + block_or_loop: Union[BlockRV, LoopRV] + The block/loop to be annotated + ann_key : str + The annotation key + ann_val : Union[str, int, float, ExprRV] + The annotation value + + Examples + -------- + + Before annotate, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_annotate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do annotate: + + .. code-block:: python + + sch = tir.Schedule(before_annotate) + sch.annotate(sch.get_block("B"), "ann_key", "ann_value") + print(sch.mod["main"].script()) + + After applying annotate, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_annotate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"ann_key", "ann_value"}) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + if isinstance(ann_val, str): + ann_val = String(ann_val) + elif isinstance(ann_val, int): + ann_val = IntImm("int32", ann_val) + elif isinstance(ann_val, float): + ann_val = FloatImm("float32", ann_val) + _ffi_api.ScheduleAnnotate( # pylint: disable=no-member + self, block_or_loop, ann_key, ann_val + ) + + def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> None: + """Unannotate a block/loop's annotation with key ann_key + + Parameters + ---------- + block_or_loop: Union[BlockRV, LoopRV] + The block/loop to be unannotated + ann_key : str + The annotation key + + Examples + -------- + + Before unannotate, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_unannotate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"ann_key", "ann_value"}) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do annotate: + + .. code-block:: python + + sch = tir.Schedule(before_unannotate) + sch.unannotate(sch.get_block("B"), "ann_key") + print(sch.mod["main"].script()) + + After applying unannotate, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_unannotate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleUnannotate(self, block_or_loop, ann_key) # pylint: disable=no-member + ########## Schedule: Misc ########## @type_checked diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 4db4cd4ba1c82..3afaa833588bc 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -563,6 +563,53 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ + +ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_val) { + if (ann_val.as()) { + return ann_val; + } + if (const auto* expr = ann_val.as()) { + ICHECK(!ann_val->IsInstance()) + << "TypeError: runtime::String is expected, but gets StringImm"; + return this->Get(GetRef(expr)); + } + LOG(FATAL) + << "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but " + << "gets: " << ann_val->GetTypeKey(); + throw; +} + +void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, + const ObjectRef& ann_val) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->CheckAndGetAnnotationValue(ann_val)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); +} + +void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_); +} + +void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key, + const ObjectRef& ann_val) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Annotate(state_, this->GetSRef(block_rv), ann_key, + this->CheckAndGetAnnotationValue(ann_val)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); +} + +void ConcreteScheduleNode::Unannotate(const BlockRV& loop_rv, const String& ann_key) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_); +} + /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index e4df5f893ae98..a8cd510709814 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -120,6 +120,11 @@ class ConcreteScheduleNode : public ScheduleNode { int offset) override; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ + void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; + void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; + void Annotate(const BlockRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; + void Unannotate(const BlockRV& loop_rv, const String& ann_key) override; + /******** Schedule: Misc ********/ void EnterPostproc() override {} @@ -161,6 +166,13 @@ class ConcreteScheduleNode : public ScheduleNode { inline Array CreateRV(const std::vector& value); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); + /*! + * \brief Check the annotation value is valid and look up the random variable. Raises an exception + * if the type of the annotation value is not allowed. + * \param The annotation value. + * \return The annotation value with random variables substituted with their values. + */ + ObjectRef CheckAndGetAnnotationValue(const ObjectRef& ann_val); }; // implementations diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index cc7e44d4df9e6..609de274d80f4 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -340,6 +340,23 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ +/*! + * \brief Annotate a block/loop with a key value pair + * \param self The state of the schedule + * \param sref The block/loop sref to be annotated + * \param ann_key The annotation key + * \param ann_val The annotation value + */ +TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, + const ObjectRef& ann_val); +/*! + * \brief Unannotate a block/loop's annotation with key ann_key + * \param self The state of the schedule + * \param sref The block/loop to be unannotated + * \param ann_key The annotation key + */ +TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key); + /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc new file mode 100644 index 0000000000000..09b7a47e8ee85 --- /dev/null +++ b/src/tir/schedule/primitive/annotate.cc @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, + const ObjectRef& ann_val) { + // Extract annotation + const Map* annotations = nullptr; + if (const auto* loop = sref->StmtAs()) { + annotations = &loop->annotations; + } else if (const auto* block = sref->StmtAs()) { + annotations = &block->annotations; + } else { + LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + } + // Check if the annotation already exists + if (annotations->find(ann_key) != annotations->end()) { + return; + } + // Add the new annotation + Map new_ann(*annotations); + new_ann.Set(ann_key, ann_val); + // Create the new stmt + if (const auto* loop = sref->StmtAs()) { + ObjectPtr n = make_object(*loop); + n->annotations = std::move(new_ann); + self->Replace(sref, For(n), {}); + } else if (const auto* block = sref->StmtAs()) { + ObjectPtr n = make_object(*block); + n->annotations = std::move(new_ann); + Block p(n); + self->Replace(sref, p, {{GetRef(block), p}}); + } else { + LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + throw; + } +} + +void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key) { + // Extract annotation + const Map* annotations = nullptr; + if (const auto* loop = sref->StmtAs()) { + annotations = &loop->annotations; + } else if (const auto* block = sref->StmtAs()) { + annotations = &block->annotations; + } else { + LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + } + // Remove the annotation + ICHECK(annotations->find(ann_key) != annotations->end()) + << "IndexError: Cannot find annotation key: " << ann_key; + Map new_ann(*annotations); + new_ann.erase(ann_key); + // Create the new stmt + if (const auto* loop = sref->StmtAs()) { + ObjectPtr n = make_object(*loop); + n->annotations = std::move(new_ann); + self->Replace(sref, For(n), {}); + } else if (const auto* block = sref->StmtAs()) { + ObjectPtr n = make_object(*block); + n->annotations = std::move(new_ann); + Block p(n); + self->Replace(sref, p, {{GetRef(block), p}}); + } else { + LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + throw; + } +} + +struct AnnotateTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Annotate"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ObjectRef ann_val, + String ann_key) { + if (const auto* block = block_or_loop_rv.as()) { + return sch->Annotate(GetRef(block), ann_key, ann_val); + } + if (const auto* loop = block_or_loop_rv.as()) { + return sch->Annotate(GetRef(loop), ann_key, ann_val); + } + LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); + throw; + } + + static String UnpackedAsPython(Array outputs, ObjectRef block_or_loop_rv, + ObjectRef ann_val, String ann_key) { + PythonAPICall py("annotate"); + py.Input("block_or_loop", block_or_loop_rv); + py.Input("ann_key", ann_key); + if (const auto* int_imm = ann_val.as()) { + py.Input("ann_val", std::to_string(int_imm->value)); + } else if (const auto* str_imm = ann_val.as()) { + py.Input("ann_val", GetRef(str_imm)); + } else if (const auto* expr = ann_val.as()) { + std::ostringstream os; + os << GetRef(expr); + py.Input("ann_val", os.str()); + } else { + LOG(FATAL) << "TypeError: Cannot handle type: " << ann_val->GetTypeKey(); + throw; + } + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +struct UnannotateTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Unannotate"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String ann_key) { + if (const auto* block = block_or_loop_rv.as()) { + return sch->Unannotate(GetRef(block), ann_key); + } + if (const auto* loop = block_or_loop_rv.as()) { + return sch->Unannotate(GetRef(loop), ann_key); + } + LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); + throw; + } + + static String UnpackedAsPython(Array outputs, ObjectRef block_or_loop_rv, + String ann_key) { + PythonAPICall py("unannotate"); + py.Input("block_or_loop", block_or_loop_rv); + py.Input("ann_key", ann_key); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(AnnotateTraits); +TVM_REGISTER_INST_KIND_TRAITS(UnannotateTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index a411e40b13b6c..0e98cb2172bb1 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -182,6 +182,32 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") .set_body_method(&ScheduleNode::StorageAlign); /******** (FFI) Blockize & Tensorize ********/ /******** (FFI) Annotation ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") + .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, + const ObjectRef& ann_val) { + if (const auto* block_rv = rv.as()) { + return self->Annotate(GetRef(block_rv), ann_key, ann_val); + } + if (const auto* loop_rv = rv.as()) { + return self->Annotate(GetRef(loop_rv), ann_key, ann_val); + } + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; + throw; + }); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") + .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key) { + if (const auto* block_rv = rv.as()) { + return self->Unannotate(GetRef(block_rv), ann_key); + } + if (const auto* loop_rv = rv.as()) { + return self->Unannotate(GetRef(loop_rv), ann_key); + } + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; + throw; + }); + /******** (FFI) Misc ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 4a028d1dad5c6..73ad1428d8b12 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -333,6 +333,44 @@ void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, /******** Schedule: Annotation ********/ +void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, + const ObjectRef& ann_val) { + ConcreteScheduleNode::Annotate(loop_rv, ann_key, ann_val); + static const InstructionKind& kind = InstructionKind::Get("Annotate"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, ann_val}, + /*attrs=*/{ann_key}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key, + const ObjectRef& ann_val) { + ConcreteScheduleNode::Annotate(block_rv, ann_key, ann_val); + static const InstructionKind& kind = InstructionKind::Get("Annotate"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv, ann_val}, + /*attrs=*/{ann_key}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) { + ConcreteScheduleNode::Unannotate(loop_rv, ann_key); + static const InstructionKind& kind = InstructionKind::Get("Unannotate"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{ann_key}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_key) { + ConcreteScheduleNode::Unannotate(block_rv, ann_key); + static const InstructionKind& kind = InstructionKind::Get("Unannotate"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{ann_key}, + /*outputs=*/{})); +} + /******** Schedule: Misc ********/ void TracedScheduleNode::EnterPostproc() { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index ac36b9ca06a97..ea27f28cd2250 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -86,6 +86,10 @@ class TracedScheduleNode : public ConcreteScheduleNode { int offset) final; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ + void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; + void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; + void Annotate(const BlockRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; + void Unannotate(const BlockRV& loop_rv, const String& ann_key) override; /******** Schedule: Misc ********/ void EnterPostproc() final; }; diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index d75bc1461c5e4..e01d469d8ec57 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -61,6 +61,46 @@ def matmul_relu(a: T.handle, b: T.handle, d: T.handle) -> None: D[vi, vj] = T.max(C[vi, vj], 0.0) +@T.prim_func +def matmul_relu_ann1(a: T.handle, b: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (1024, 1024)) + B = T.match_buffer(b, (1024, 1024)) + C = T.alloc_buffer((1024, 1024)) + D = T.match_buffer(d, (1024, 1024)) + for i in T.serial(0, 1024, annotations={"test1": "aaa"}): + for j in T.serial(0, 1024, annotations={"test2": 612}): + for k in T.serial(0, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(1024, 1024): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) + + +@T.prim_func +def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (1024, 1024)) + B = T.match_buffer(b, (1024, 1024)) + C = T.alloc_buffer((1024, 1024)) + D = T.match_buffer(d, (1024, 1024)) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + T.block_attr({"test1": "aaa"}) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(1024, 1024): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"test2": 0.22}) + D[vi, vj] = T.max(C[vi, vj], 0.0) + + # pylint: enable=no-member,invalid-name,unused-variable @@ -199,5 +239,31 @@ def test_get_consumers(): verify_trace_roundtrip(sch, mod=matmul_relu) +def test_annotate_unannotate_loop(): + sch = tir.Schedule(mod=matmul_relu, debug_mask="all") + matmul = sch.get_block("matmul") + relu = sch.get_block("relu") + sch.annotate(sch.get_loops(matmul)[0], "test1", "aaa") + sch.annotate(sch.get_loops(matmul)[1], "test2", 612) + tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann1) + verify_trace_roundtrip(sch=sch, mod=matmul_relu) + sch.unannotate(sch.get_loops(matmul)[0], "test1") + sch.unannotate(sch.get_loops(matmul)[1], "test2") + verify_trace_roundtrip(sch=sch, mod=matmul_relu) + + +def test_annotate_unannotate_block(): + sch = tir.Schedule(mod=matmul_relu, debug_mask="all") + matmul = sch.get_block("matmul") + relu = sch.get_block("relu") + sch.annotate(matmul, "test1", "aaa") + sch.annotate(relu, "test2", 0.22) + tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann2) + verify_trace_roundtrip(sch=sch, mod=matmul_relu) + sch.unannotate(matmul, "test1") + sch.unannotate(relu, "test2") + verify_trace_roundtrip(sch=sch, mod=matmul_relu) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))