diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 22febfdfedec..215e330a4c6f 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -584,13 +584,23 @@ class ScheduleNode : public runtime::Object {
   virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
                             int offset) = 0;
   /*!
-   * \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a
+   * \brief Set the storage scope of a buffer, where the buffer is specified by a block and a
    * write-index
    * \param block_rv The producer block of the buffer
    * \param buffer_index The index of the buffer in block's write region
    * \param storage_scope The storage scope to be set
    */
   virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
+  /*!
+   * \brief Set the data type of a buffer, where the buffer is specified by a block and a
+   * write-index
+   * \note This schedule primitive is unsafe and may change correctness of program because of
+   * type conversion, please use with caution.
+   * \param block_rv The producer block of the buffer
+   * \param buffer_index The index of the buffer in block's write region
+   * \param dtype The data type to be set
+   */
+  virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) = 0;
   /******** Schedule: Blockize & Tensorize ********/
   /*!
    * \brief Convert the subtree rooted at a specific loop into a block.
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 896e2fc48e72..9269acdd78ba 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2322,7 +2322,7 @@ def after_storage_align(a: T.handle, c: T.handle) -> None:
     @type_checked
     def set_scope(self, block: Union[BlockRV, str], buffer_index: int, storage_scope: str) -> None:
         """Set the storage scope of a buffer, where the buffer is
-        specified by the a block and a write-index
+        specified by a block and a write-index.
 
         Parameters
         ----------
@@ -2384,13 +2384,88 @@ def after_set_scope(
 
         Note
         ----
-        Set_scope requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
+        `set_scope` requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
         """
         block = self._normalize_block_arg(block)
         _ffi_api.ScheduleSetScope(  # type: ignore # pylint: disable=no-member
             self, block, buffer_index, storage_scope
         )
 
+    @type_checked
+    def unsafe_set_dtype(self, block: Union[BlockRV, str], buffer_index: int, dtype: str) -> None:
+        """Set the data type of a buffer, where the buffer is
+        specified by a block and a write-index.
+
+        This schedule primitive is unsafe and may change the correctness of program because of
+        type conversion, please use with caution.
+
+        Parameters
+        ----------
+        block : Union[BlockRV, str]
+            The producer block of the buffer
+        buffer_index : int
+            The index of the buffer in block's write region
+        dtype : str
+            The data type to be set
+
+        Examples
+        --------
+
+        Before unsafe_set_dtype, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_set_dtype(
+                A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
+            ) -> None:
+                B = T.alloc_buffer((128, 128), dtype="float32")
+
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] * 2.0
+                for i, j in T.grid(128, 128):
+                    with T.block("C"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        C[vi, vj] = B[vi, vj] + 1.0
+
+        Create the schedule and do unsafe_set_dtype:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_set_dtype)
+            sch.unsafe_set_dtype("B", buffer_index=0, dtype="float16")
+            print(sch.mod["main"].script())
+
+        After applying unsafe_set_dtype, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_set_dtype(
+                A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
+            ) -> None:
+                B = T.alloc_buffer((128, 128), dtype="float16")
+
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
+                for i, j in T.grid(128, 128):
+                    with T.block("C"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
+
+        Note
+        ----
+        `unsafe_set_dtype` requires an intermediate buffer defined via `alloc_buffer`.
+        """
+        block = self._normalize_block_arg(block)
+        _ffi_api.ScheduleUnsafeSetDType(  # type: ignore # pylint: disable=no-member
+            self, block, buffer_index, dtype
+        )
+
     ########## Schedule: Blockize & Tensorize ##########
 
     @type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 5a9dab4854bd..330486b86ba2 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -701,6 +701,14 @@ void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
   this->state_->DebugVerify();
 }
 
+void ConcreteScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index,
+                                          const String& dtype) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype);
+  TVM_TIR_SCHEDULE_END("set-dtype", this->error_render_level_);
+  this->state_->DebugVerify();
+}
+
 /******** Schedule: Reduction ********/
 
 BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 82ac9f913374..93f094304bf4 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -141,6 +141,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
                     int offset) override;
   void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override;
+  void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) override;
   /******** Schedule: Blockize & Tensorize ********/
   BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
   void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 563864229a26..9c3540eb3d68 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -470,6 +470,18 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu
  */
 TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                       const String& storage_scope);
+/*!
+ * \brief Set the data type of a buffer, where the buffer is specified by a block and a
+ * write-index
+ * \note This schedule primitive is unsafe and may change correctness of program because of
+ * type conversion, please use with caution.
+ * \param self The state of the schedule
+ * \param block_sref The sref of the producer block of the buffer
+ * \param buffer_index The index of the buffer in block's write region
+ * \param dtype The data type to be set
+ */
+TVM_DLL void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
+                            const String& dtype);
 /*!
  * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
  * or write index
diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc
index 0912e36836e3..3f1789b3d6e6 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -16,6 +16,8 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <tvm/tir/expr.h>
+
 #include "../utils.h"
 
 namespace tvm {
@@ -297,6 +299,93 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
   self->Replace(alloc_site_sref, new_block, block_reuse_map);
 }
 
+/*!
+ * \brief A helper mutator which recursively mutates the old buffer's data type, inserts data type
+ * conversions, and collects the block sref reuse information for the following replacement.
+ */
+class DTypeMutator : private ReplaceBufferMutator {
+ public:
+  /*!
+   * \param allocate_site The block where `old_buffer` was allocated.
+   * \param old_buffer The old buffer
+   * \param dtype The data type to be set
+   * \param block_sref_reuse The block sref reuse map to be updated
+   * \return The new block after the mutation
+   */
+  static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, const DataType& dtype,
+                      Map<Block, Block>* block_sref_reuse) {
+    Buffer new_buffer = WithDType(old_buffer, dtype);
+    DTypeMutator mutator(old_buffer, new_buffer, dtype, block_sref_reuse);
+    Stmt new_block = mutator.VisitStmt(allocate_site);
+    return Downcast<Block>(new_block);
+  }
+
+ private:
+  DTypeMutator(const Buffer& old_buffer, Buffer new_buffer, const DataType& dtype,
+               Map<Block, Block>* block_sref_reuse)
+      : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse),
+        src_dtype_(old_buffer->dtype),
+        tgt_dtype_(dtype) {}
+
+  MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final {
+    auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      Buffer new_target_buffer = WithDType(match_buffer->buffer, it->second->dtype);
+      buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
+      return MatchBufferRegion(new_target_buffer,
+                               BufferRegion(it->second, match_buffer->source->region));
+    } else {
+      return match_buffer;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    auto it = buffer_var_map_.find(node->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      node.CopyOnWrite()->buffer = it->second;
+      node.CopyOnWrite()->value = Cast(tgt_dtype_, node->value);
+    }
+    return node;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    BufferLoad node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    auto it = buffer_var_map_.find(node->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      return Cast(src_dtype_, BufferLoad(it->second, node->indices));
+    }
+    return node;
+  }
+
+  DataType src_dtype_, tgt_dtype_;
+};
+
+void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
+                    const String& dtype) {
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  Buffer buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kWrite);
+  DataType target_dtype(runtime::String2DLDataType(dtype));
+
+  // Step 1. If `dtype` equals the original data type, just return.
+  if (buffer->dtype == target_dtype) {
+    return;
+  }
+
+  // Step 2. Get the allocation site of the target buffer.
+  StmtSRef alloc_site_sref =
+      NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer);
+  const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref);
+
+  // Step 3. Recursively replace old buffer to a new buffer, where the new buffer has the given
+  // dtype, and insert data type conversions.
+  Map<Block, Block> block_reuse_map;
+  Block new_block =
+      DTypeMutator::Mutate(GetRef<Block>(alloc_site), buffer, target_dtype, &block_reuse_map);
+  self->Replace(alloc_site_sref, new_block, block_reuse_map);
+}
+
 /******** InstructionKind Registration ********/
 
 struct StorageAlignTraits : public UnpackedInstTraits<StorageAlignTraits> {
@@ -356,8 +445,36 @@ struct SetScopeTraits : public UnpackedInstTraits<SetScopeTraits> {
   friend struct ::tvm::tir::UnpackedInstTraits;
 };
 
+struct UnsafeSetDTypeTraits : public UnpackedInstTraits<UnsafeSetDTypeTraits> {
+  static constexpr const char* kName = "UnsafeSetDType";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 2;
+  static constexpr size_t kNumDecisions = 0;
+
+  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index,
+                                      String dtype) {
+    return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index,
+                                 String dtype) {
+    PythonAPICall py("unsafe_set_dtype");
+    py.Input("block", block_rv);
+    py.Input("buffer_index", buffer_index);
+    py.Input("dtype", dtype);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits);
 TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits);
+TVM_REGISTER_INST_KIND_TRAITS(UnsafeSetDTypeTraits);
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index cb8b5a1d7787..a3d5346f7fe1 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -211,6 +211,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
     .set_body_method<Schedule>(&ScheduleNode::StorageAlign);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
     .set_body_method<Schedule>(&ScheduleNode::SetScope);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType")
+    .set_body_method<Schedule>(&ScheduleNode::UnsafeSetDType);
 /******** (FFI) Blockize & Tensorize ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
     .set_body_method<Schedule>(&ScheduleNode::Blockize);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index a5cb66a0cb44..2b3a3e54b5d3 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -475,6 +475,17 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
       /*outputs=*/{}));
 }
 
+void TracedScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index,
+                                        const String& dtype) {
+  ConcreteScheduleNode::UnsafeSetDType(block_rv, buffer_index, dtype);
+  static const InstructionKind& kind = InstructionKind::Get("UnsafeSetDType");
+  trace_->Append(/*inst=*/Instruction(
+      /*kind=*/kind,
+      /*inputs=*/{block_rv},
+      /*attrs=*/{Integer(buffer_index), dtype},
+      /*outputs=*/{}));
+}
+
 /******** Schedule: Blockize & Tensorize ********/
 
 BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 1fcba9806380..e59dc564aadb 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -100,6 +100,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
                     int offset) final;
   void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final;
+  void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) final;
   /******** Schedule: Blockize & Tensorize ********/
   BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final;
   void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final;
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index e91c5d142c04..baa7f44bbcf2 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -43,6 +43,16 @@ Buffer WithScope(const Buffer& buffer, const String& scope) {
   return Buffer(new_buffer);
 }
 
+Buffer WithDType(const Buffer& buffer, const DataType& dtype) {
+  ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
+  new_buffer->dtype = dtype;
+  const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
+  new_buffer->data =
+      Var(buffer->data->name_hint, PointerType(PrimType(dtype), ptr_type->storage_scope));
+  new_buffer->name = buffer->name;
+  return Buffer(new_buffer);
+}
+
 Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& source,
                                   const Buffer& target) {
   regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion {
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 3593d6b9a444..d2412436c7fb 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -53,6 +53,14 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec
  */
 Buffer WithScope(const Buffer& buffer, const String& scope);
 
+/*!
+ * \brief Create a new buffer by changing the data type.
+ * \param buffer The given buffer.
+ * \param dtype The target data type.
+ * \return The new buffer with target data type.
+ */
+Buffer WithDType(const Buffer& buffer, const DataType& dtype);
+
 /*!
  * \brief Replaces the buffer within the specific sequence of regions
  * \param regions The regions whose buffers are to be replaced
@@ -131,9 +139,9 @@ class ReplaceBufferMutator : public StmtExprMutator {
     return node;
   }
 
-  Stmt VisitStmt_(const BufferStoreNode* op) final;
+  Stmt VisitStmt_(const BufferStoreNode* op) override;
 
-  PrimExpr VisitExpr_(const BufferLoadNode* op) final;
+  PrimExpr VisitExpr_(const BufferLoadNode* op) override;
 
   virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer);
 
diff --git a/tests/python/unittest/test_tir_schedule_set_dtype.py b/tests/python/unittest/test_tir_schedule_set_dtype.py
new file mode 100644
index 000000000000..7f0900619b9b
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_set_dtype.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+@T.prim_func
+def element_wise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")) -> None:
+    B = T.alloc_buffer((128, 128), dtype="float32")
+
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[vi, vj] = B[vi, vj] + 1.0
+
+@T.prim_func
+def element_wise_set_dtype(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
+    B = T.alloc_buffer((128, 128), "float16")
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(A[vi, vj])
+            T.writes(B[vi, vj])
+            B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(B[vi, vj])
+            T.writes(C[vi, vj])
+            C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
+
+@T.prim_func
+def element_wise_subregion_match(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")) -> None:
+    B = T.alloc_buffer((128, 128), dtype="float32")
+
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1)
+            B_subregion0[()] = A[vi, vj] * 2.0
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1)
+            C[vi, vj] = B_subregion1[()] + 1.0
+
+
+@T.prim_func
+def element_wise_subregion_match_set_dtype(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")) -> None:
+    B = T.alloc_buffer((128, 128), "float16")
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(A[vi, vj])
+            T.writes(B[vi, vj])
+            B_subregion0 = T.match_buffer(B[vi, vj], (), "float16", offset_factor=1)
+            B_subregion0[()] = T.cast(A[vi, vj] * 2.0, "float16")
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(B[vi, vj])
+            T.writes(C[vi, vj])
+            B_subregion1 = T.match_buffer(B[vi, vj], (), "float16", offset_factor=1)
+            C[vi, vj] = T.cast(B_subregion1[()], "float32") + 1.0
+
+
+use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
+
+def test_set_dtype(use_block_name):
+    func = element_wise
+    sch = tir.Schedule(func, debug_mask="all")
+    sch.unsafe_set_dtype("B" if use_block_name else sch.get_block("B"), 0, "float16")
+    tvm.ir.assert_structural_equal(element_wise_set_dtype, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=func)
+
+def test_set_dtype_fail_on_output_buffer(use_block_name):
+    func = element_wise
+    sch = tir.Schedule(func, debug_mask="all")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.unsafe_set_dtype("C" if use_block_name else sch.get_block("C"), 0, "float16")
+
+def test_set_dtype_fail_on_index_out_of_bound():
+    func = element_wise
+    sch = tir.Schedule(func, debug_mask="all")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.unsafe_set_dtype(sch.get_block("B"), 1, "float64")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.unsafe_set_dtype(sch.get_block("B"), -1, "float64")
+
+def test_set_dtype_subregion():
+    func = element_wise_subregion_match
+    sch = tir.Schedule(func, debug_mask="all")
+    sch.unsafe_set_dtype(sch.get_block("B"), 0, "float16")
+    tvm.ir.assert_structural_equal(element_wise_subregion_match_set_dtype, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=func)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()