diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py
index f2fb7c4f3d1d..71ff024217c7 100644
--- a/python/tvm/tir/schedule/analysis.py
+++ b/python/tvm/tir/schedule/analysis.py
@@ -17,12 +17,16 @@
 """Analysis used in TensorIR scheduling"""
 from typing import List, Optional
 
+import tvm._ffi
+from tvm.runtime import Object
+
 from ..buffer import Buffer
 from ..stmt import For
 from ..expr import PrimExpr
-from ..function import IndexMap
+from ..function import IndexMap, PrimFunc
 
 from . import _ffi_api
+from .schedule import Schedule, BlockRV
 
 
 def suggest_index_map(
@@ -56,3 +60,30 @@ def suggest_index_map(
         loops,
         predicate,
     )
+
+
+@tvm._ffi.register_object("tir.schedule.TensorizeInfo")
+class TensorizeInfo(Object):
+    """Necessary information used for tensorization."""
+
+
+def get_tensorize_loop_mapping(
+    sch: Schedule, block: BlockRV, desc_func: PrimFunc
+) -> Optional[TensorizeInfo]:
+    """Establish a mapping between loops in a target block and an intrinsic description
+
+    Parameters
+    ----------
+    sch : Schedule
+        The schedule to be tensorized
+    block : BlockRV
+        The target block to match against
+    desc_func : PrimFunc
+        The prim func describing the computation to be tensorized
+
+    Returns
+    -------
+    tensorize_info : Optional[TensorizeInfo]
+        TensorizeInfo structure if a valid mapping is found, None otherwise
+    """
+    return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func)  # type: ignore
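For reviewers who want to exercise the new Python entry point, here is a minimal usage sketch (not part of the patch; DenseVNNIModule refers to the test module added at the end of this diff):

    import tvm
    from tvm.tir import Schedule
    from tvm.tir.schedule.analysis import get_tensorize_loop_mapping
    from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc

    sch = Schedule(DenseVNNIModule)
    block = sch.get_block("compute")
    info = get_tensorize_loop_mapping(sch, block, dot_product_16x4_u8i8i32_desc)
    if info is None:
        print("the block cannot be matched against this intrinsic description")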
diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tir/stmt_functor.py
index 56dc1c20c2b3..5bcf4ae802c7 100644
--- a/python/tvm/tir/stmt_functor.py
+++ b/python/tvm/tir/stmt_functor.py
@@ -58,6 +58,18 @@ def post_order_visit(stmt, fvisit):
     return _ffi_api.PostOrderVisit(stmt, fvisit)  # type: ignore
 
 
+def pre_order_visit(stmt, fvisit):
+    """Recursive pre-order visit on stmt AST, applying fvisit on each node.
+    If fvisit returns False, it won't visit the children of the node.
+
+    Parameters
+    ----------
+    fvisit: function of the signature Object -> bool
+        The visitor function.
+    """
+    return _ffi_api.PreOrderVisit(stmt, fvisit)  # type: ignore
+
+
 def substitute(node, vmap):
     """Substitute the var specified by vmap.
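For reference, the early-exit behavior is what distinguishes pre_order_visit from post_order_visit: returning False from the callback prunes the subtree below the current node. A small illustrative sketch (the collect_loops helper in the tests below uses the same pattern, without pruning):

    import tvm
    from tvm.tir.stmt_functor import pre_order_visit

    def collect_loops_outside_blocks(stmt):
        loops = []

        def callback(node):
            if isinstance(node, tvm.tir.Block):
                return False  # prune: do not descend into nested blocks
            if isinstance(node, tvm.tir.For):
                loops.append(node)
            return True  # keep descending

        pre_order_visit(stmt, callback)
        return loops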
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index c4d7ad0f6c67..06933c2c0dcb 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -792,6 +792,10 @@ TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, Pack
   tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); });
 });
 
+TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
+  tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n); });
+});
+
 TVM_REGISTER_GLOBAL("tir.Substitute")
     .set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef {
      if (node->IsInstance<StmtNode>()) {
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index b76d41326ff1..c9c3d72ae0b5 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -656,6 +656,39 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const P
                                              const StmtSRef& dom_high_exclusive,
                                              arith::Analyzer* analyzer);
 
+/*! \brief Necessary information used for tensorization */
+class TensorizeInfoNode : public Object {
+ public:
+  /*! \brief Maps loops in a target block to the ones in an intrinsic description */
+  Map<tir::StmtSRef, tir::For> loop_map;
+  /*! \brief Maps loops in an intrinsic description to their indices, outer to inner */
+  Map<tir::For, Integer> desc_loop_indexer;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("loop_map", &loop_map);
+    v->Visit("desc_loop_indexer", &desc_loop_indexer);
+  }
+
+  static constexpr const char* _type_key = "tir.schedule.TensorizeInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
+};
+
+class TensorizeInfo : public ObjectRef {
+ public:
+  TVM_DEFINE_NOT_NULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
+};
+
+/*!
+ * \brief Establish a mapping between loops in a target block and an intrinsic description
+ * \param self The schedule state to be tensorized
+ * \param block_sref The target block to match against
+ * \param desc_func The prim func describing the computation to be tensorized
+ * \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
+ */
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& block_sref,
+                                                const tir::PrimFunc& desc_func);
+
 }  // namespace tir
 }  // namespace tvm
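Because both maps are registered via VisitAttrs, they are readable from Python as attributes of the TensorizeInfo object returned above. A sketch of how a caller might inspect a successful match (sch, block, and desc_func are placeholders as in the earlier sketch):

    info = get_tensorize_loop_mapping(sch, block, desc_func)
    if info is not None:
        for block_loop_sref, desc_loop in info.loop_map.items():
            desc_index = int(info.desc_loop_indexer[desc_loop])
            print(sch.get(block_loop_sref), "-> desc loop", desc_index)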
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 4a7ac401dd60..4777ee2657b3 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -16,6 +16,9 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <algorithm>
+#include <unordered_set>
+
 #include "../utils.h"
 
 namespace tvm {
@@ -492,8 +495,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
   }
 }
 
-std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+std::vector<IterVarType> GetBlockVarTypes(const BlockNode* block) {
   std::vector<IterVarType> results;
   results.reserve(block->iter_vars.size());
   for (const IterVar& iter_var : block->iter_vars) {
@@ -502,6 +504,11 @@ std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
   return results;
 }
 
+std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  return GetBlockVarTypes(block);
+}
+
 bool IsWriteCache(const StmtSRef& block_sref) {
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   if (block->writes.size() != 1) {
@@ -2028,5 +2035,161 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,  //
   }
 }
 
+TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& block_sref,
+                                                const tir::PrimFunc& desc_func) {
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  const tir::BlockRealizeNode* desc_block = nullptr;
+  std::vector<const tir::ForNode*> desc_loops;
+  std::unordered_set<const tir::VarNode*> desc_loop_vars;
+  const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
+  ICHECK(desc_scope_realize);
+  {
+    auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
+                    &analyzer](const ObjectRef& obj) -> bool {
+      // Extract the block
+      if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
+        desc_block = block;
+        return false;
+      }
+      // Extract loops
+      if (const auto* loop = obj.as<tir::ForNode>()) {
+        desc_loops.push_back(loop);
+        desc_loop_vars.insert(loop->loop_var.get());
+        if (!analyzer.CanProve(loop->min == 0)) {
+          return false;
+        }
+      }
+      return true;
+    };
+    tir::PreOrderVisit(desc_scope_realize->block->body, f_visit);
+    ICHECK(desc_block);
+  }
+  // Step 2. Collect loops from block_sref
+  const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
+  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  std::vector<const tir::ForNode*> block_loops;
+  std::unordered_set<const tir::VarNode*> block_loop_vars;
+  {
+    for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
+      const auto* loop = loop_sref->StmtAs<tir::ForNode>();
+      if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
+        break;
+      }
+      block_loops.push_back(loop);
+      block_loop_vars.insert(loop->loop_var.get());
+      if (!analyzer.CanProve(loop->min == 0)) {
+        return NullOpt;
+      }
+    }
+    std::reverse(block_loops.begin(), block_loops.end());
+  }
+  // Step 3. Map from block loops to desc block loops
+  ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
+  const int n_block_vars = block->iter_values.size();
+  const int n_desc_vars = desc_block->iter_values.size();
+  const int offset = n_block_vars - n_desc_vars;
+
+  if (offset < 0) {
+    return NullOpt;
+  }
+
+  const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref);
+  const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get());
+
+  ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
+  ICHECK(block_loops.size() == iter_types_block.size());
+
+  // We assume that the orders of iter_vars in the target and the desc block are consistent.
+  // Based on that assumption, the following logic supports arbitrary permutations of a loop order,
+  // such as
+
+  // for k:
+  //   for i:
+  //     for j:
+  //       C[i, j] += A[i, k] * B[k, j]
+
+  // or
+
+  // for i:
+  //   for j:
+  //     for k:
+  //       C[i, j] += A[i, k] * B[k, j]
+
+  int next_block_ind = block_loops.size() - 1;
+  for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
+    // Step 3.1. Find the corresponding loop of the i_desc-th block var of desc
+    const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
+    const tir::ForNode* desc_loop = nullptr;
+    IterVarType iter_type_desc = iter_types_desc[i_desc];
+    for (int i = 0, n = desc_loops.size(); i < n; ++i) {
+      // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
+      PrimExpr residual = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
+      if (!UsesVar(residual,
+                   [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) {
+        desc_loop = desc_loops[i];
+        iter_type_desc = iter_types_desc[i];
+        break;
+      }
+    }
+    if (desc_loop == nullptr || desc_loop->extent.as<IntImmNode>() == nullptr) {
+      return NullOpt;
+    }
+
+    const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();
+
+    // Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type
+    PrimExpr block_bind;
+    for (int i = next_block_ind; i >= 0; --i) {
+      if (iter_types_block[i] == iter_type_desc) {
+        next_block_ind = i - 1;
+        block_bind = block->iter_values[i];
+        break;
+      }
+    }
+
+    if (!block_bind.defined()) return NullOpt;
+
+    // Step 3.3. Find the corresponding loop of the target block
+    for (int i = 0, n = block_loops.size(); i < n; ++i) {
+      // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
+      const tir::ForNode* block_loop = block_loops[i];
+      const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
+      // Skip i-th loop if it has already been mapped
+      if (ret->loop_map.find(block_loop_sref) != ret->loop_map.end()) continue;
+
+      PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
+      if (UsesVar(residual,
+                  [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); }))
+        continue;
+
+      const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();
+
+      // Check divisibility
+      if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
+        return NullOpt;
+      }
+
+      ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
+      break;
+    }
+  }
+
+  for (int i = 0, n = desc_loops.size(); i < n; ++i) {
+    ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
+  }
+  return TensorizeInfo(ret);
+}
+
+TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
+    .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
+      return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
+    });
+
 }  // namespace tir
 }  // namespace tvm
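To make Step 3 concrete, consider the DenseVNNIModule test case below. The target block has iter vars (i: S, j: S, k: R) over loops of extents (1024, 1024, 1024), while the description has (i: S, k: R) over extents (16, 4). Matching walks the desc vars from innermost to outermost: the R-typed desc var (extent 4) pairs with the rightmost R-typed block var k, since 1024 % 4 == 0, and the S-typed desc var (extent 16) then pairs with the rightmost remaining S-typed var j, since 1024 % 16 == 0. The outer loop i is left unmapped, which is why the dense test asserts mappings only for loop_j and loop_k:

    # Expected result for DenseVNNIModule vs. dot_product_16x4_u8i8i32_desc:
    #   loop j (extent 1024) -> desc loop 0 (extent 16)
    #   loop k (extent 1024) -> desc loop 1 (extent 4)
    #   loop i remains unmapped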
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 760b412ac804..10371d3ccaf1 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -17,18 +17,17 @@
 # pylint: disable=missing-docstring
 from typing import List
 
-from tvm.tir import (
-    Evaluate,
-    For,
-    ForKind,
-    IndexMap,
-    Var,
-    decl_buffer,
-    floordiv,
-    floormod,
-)
+import tvm
+from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc
+
+
+from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule
 from tvm.tir.analysis import expr_deep_equal
-from tvm.tir.schedule.analysis import suggest_index_map
+from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping, TensorizeInfo
+from tvm.script import tir as T
+from tvm.tir.stmt_functor import pre_order_visit
+from tvm.meta_schedule.testing import te_workload
+from tvm.te import create_prim_func
 
 
 def _make_vars(*args: str) -> List[Var]:
@@ -102,6 +101,168 @@ def test_suggest_index_map_bijective():
     _assert_equal_index_map(index_map, expected_index_map)
 
 
+@tvm.script.ir_module
+class DenseVNNIModule:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1024, 1024), "uint8"],
+        placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
+        compute: T.Buffer[(1024, 1024), "int32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            for i0, i1, i2 in T.grid(1024, 1024, 1024):
+                with T.block("compute"):
+                    i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+                    T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
+                    T.writes(compute[i, j])
+                    with T.init():
+                        compute[i, j] = 0
+                    compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(
+                        placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32"
+                    )
+
+
+@tvm.script.ir_module
+class Conv2dNCHWcVNNIModule:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
+        placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
+        conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
+            with T.block("conv2d_NCHWc_int8"):
+                (
+                    n,
+                    oc_chunk,
+                    oh,
+                    ow,
+                    oc_block,
+                    kh,
+                    kw,
+                    ic_outer,
+                    ic_f_inner,
+                    ic_s_inner,
+                ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
+                T.reads(
+                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
+                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+                )
+                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
+                with T.init():
+                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
+                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
+                    n, oc_chunk, oh, ow, oc_block
+                ] + T.cast(
+                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
+                ) * T.cast(
+                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+                    "int32",
+                )
+
+
+def collect_loops(prim_func):
+    loops = []
+
+    def callback(node):
+        if isinstance(node, tvm.tir.For):
+            loops.append(node)
+        return True
+
+    pre_order_visit(prim_func.body, callback)
+
+    return loops
+
+
+def test_get_tensorize_loop_mapping_dense_vnni():
+    s = Schedule(DenseVNNIModule)
+    block = s.get_block("compute")
+
+    info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)
+
+    assert isinstance(info, TensorizeInfo)
+
+    desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())
+
+    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)
+    _, loop_j, loop_k = s.get_loops(block)
+
+    assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref
+    assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(loop_j)
+    assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(loop_k)
+
+
+def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni():
+    s = Schedule(Conv2dNCHWcVNNIModule)
+    block = s.get_block("conv2d_NCHWc_int8")
+
+    info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)
+
+    desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())
+
+    desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc)
+
+    # i4 corresponds to the inner output channel axis of the NCHWc output tensor
+    # for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
+    _, _, _, _, i4, _, _, _, _, i9 = s.get_loops(block)
+
+    assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref
+    assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i4)
+    assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i9)
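The same right-to-left, type-directed matching explains the expectations in the conv2d test above (a restatement for orientation, not additional test code):

    # desc S-loop (extent 16) -> i4 / oc_block   (extent 16, 16 % 16 == 0)
    # desc R-loop (extent 4)  -> i9 / ic_s_inner (extent 4,  4 % 4 == 0)
    # i5..i8 (kh, kw, ic_outer, ic_f_inner) stay unmapped because the
    # description has only one R-typed iter var, and it matched i9 first.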
+
+
+def test_get_tensorize_loop_mapping_matmul_mma():
+    @T.prim_func
+    def matmul_16x16x16xf16f16f16_desc(
+        A: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
+        B: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
+        C: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
+    ) -> None:
+        with T.block("root"):
+            T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
+            T.writes(C[0:16, 0:16])
+            for i, j, k in T.grid(16, 16, 16):
+                with T.block("update"):
+                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
+                    C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
+
+    matmul = create_prim_func(
+        te_workload.matmul_relu(
+            n=512,
+            m=512,
+            k=512,
+        )
+    )
+
+    s = Schedule(matmul)
+    block = s.get_block("C")
+    i0, i1, i2 = s.get_loops(block)
+    desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)
+
+    for do_reorder in [False, True]:
+        # Mapping should be invariant to the loop permutation
+        if do_reorder:
+            s.reorder(i2, i0, i1)
+
+        info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc)
+        assert info is not None
+        desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())
+
+        for i in range(3):
+            assert desc_loops[i] in desc_loop_to_sref
+
+        assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0)
+        assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1)
+        assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)
+
+
 if __name__ == "__main__":
     test_suggest_index_map_simple()
     test_suggest_index_map_bijective()
+    test_get_tensorize_loop_mapping_dense_vnni()
+    test_get_tensorize_loop_mapping_conv2d_nchwc_vnni()
+    test_get_tensorize_loop_mapping_matmul_mma()
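For readers unfamiliar with the intrinsic used throughout these tests: dot_product_16x4_u8i8i32_desc, imported from tvm.tir.tensor_intrin.x86, describes a 16-lane uint8/int8 dot product accumulating into int32. Roughly, it computes the following (a paraphrase for orientation, not the exact TVMScript definition):

    # for i in range(16):      # S iter var: matched against a spatial axis of the target block
    #     for k in range(4):   # R iter var: matched against a reduction axis
    #         C[i] += cast(A[k], "int32") * cast(B[i, k], "int32")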