diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h index 378addaba528..bf90aaedfec0 100644 --- a/include/tvm/tir/data_type_rewriter.h +++ b/include/tvm/tir/data_type_rewriter.h @@ -145,6 +145,7 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter { using Parent::VisitStmt_; PrimExpr VisitExpr_(const IntImmNode* op) final; PrimExpr VisitExpr_(const VarNode* op) final; + PrimExpr VisitExpr_(const CastNode* op) final; DataType target_data_type_ = DataType::Int(64); }; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 9f8d7d46a151..223f8dcd5dd0 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -496,7 +496,7 @@ PrimFunc GenerateAndCompletePrimFunc(const Array& arg_list, PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, const Array& constants, std::optional index_dtype_override) { - // Infomations used in CreatePrimFunc and its sub-functions. + // Information used in CreatePrimFunc and its sub-functions. CreateFuncInfo info(arg_list); // Root body stmts. Array root_stmts; diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index fecb8e5fb70c..27a59d970981 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -511,6 +511,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) { IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type) : target_data_type_(std::move(target_data_type)) {} + PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) { Map new_buffer_map = func->buffer_map; for (const auto& [var, buffer] : func->buffer_map) { @@ -534,7 +535,7 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) { if (auto it = var_remap_.find(GetRef(op)); it != var_remap_.end()) { return (*it).second; } - if (is_enabled_) { + if (is_enabled_ && op->dtype != target_data_type_) { Var new_var = GetRef(op).copy_with_dtype(target_data_type_); var_remap_.Set(GetRef(op), new_var); return std::move(new_var); @@ -542,5 +543,13 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) { return GetRef(op); } +PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) { + if (is_enabled_) { + PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value); + return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value); + } + return IndexDataTypeRewriter::VisitExpr_(op); +} + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 6662c7aca85b..9a5326650184 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -46,8 +46,6 @@ def test_unique_name_reduction_block(): def _check_workload(te_workload, tir_workload, index_dtype_override=None): func = te.create_prim_func(te_workload(), index_dtype_override) - print(func.script()) - print(tvm.ir.base.get_first_structural_mismatch(func, tir_workload)) tvm.ir.assert_structural_equal(func, tir_workload) # make sure that we can create schedule from the func s = tir.Schedule(func, debug_mask="all") @@ -575,6 +573,39 @@ def expected( _check_workload(te_func, expected) +def te_reshape(): + # The following is possible to be generated by TOPI. So we test this case. + A = te.placeholder((tvm.tir.IntImm("int64", 2), tvm.tir.IntImm("int64", 4)), name="A") + B = topi.reshape(A, (4, 2)) + return [A, B] + + +@T.prim_func +def tir_reshape( + A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], + T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"], +): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, i1 in T.grid(T.int64(4), T.int64(2)): + with T.block("T_reshape"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads( + A[ + (ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4), + (ax0 * T.int64(2) + ax1) % T.int64(4), + ] + ) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = A[ + (ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4), + (ax0 * T.int64(2) + ax1) % T.int64(4), + ] + + +def test_reshape(): + _check_workload(te_reshape, tir_reshape, index_dtype_override="int64") + + if __name__ == "__main__": test_unique_name_complete_block() test_unique_name_reduction_block() @@ -593,3 +624,4 @@ def expected( test_argmax_val_idx() test_int64_indices() test_zero_dim_add() + test_reshape()