[Fix] Fix IndexDataTypeNormalizer to avoid redundant casting #13449

Merged
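
Brief context, inferred from the diff below: IndexDataTypeNormalizer previously replaced every index Var even when its dtype already matched the target, and it had no CastNode handler, so normalizing TOPI-generated expressions (such as the reshape indices in the new test) could leave redundant casts around values that already have the target dtype. A minimal sketch of the affected path, modeled on the test added in this PR (assumes a TVM build with te and topi available):

import tvm
from tvm import te, topi

# TOPI can produce shapes as int64 IntImms; this is the case the new test covers.
A = te.placeholder((tvm.tir.IntImm("int64", 2), tvm.tir.IntImm("int64", 4)), name="A")
B = topi.reshape(A, (4, 2))

# index_dtype_override routes index expressions through IndexDataTypeNormalizer;
# with this fix the reshape indices come out as plain int64 arithmetic, without extra casts.
func = te.create_prim_func([A, B], index_dtype_override="int64")
print(func.script())
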
1 change: 1 addition & 0 deletions include/tvm/tir/data_type_rewriter.h
@@ -145,6 +145,7 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter {
   using Parent::VisitStmt_;
   PrimExpr VisitExpr_(const IntImmNode* op) final;
   PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const CastNode* op) final;
 
   DataType target_data_type_ = DataType::Int(64);
 };
2 changes: 1 addition & 1 deletion src/te/operation/create_primfunc.cc
@@ -496,7 +496,7 @@ PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
 PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
                                      const Array<runtime::NDArray>& constants,
                                      std::optional<DataType> index_dtype_override) {
-  // Infomations used in CreatePrimFunc and its sub-functions.
+  // Information used in CreatePrimFunc and its sub-functions.
   CreateFuncInfo info(arg_list);
   // Root body stmts.
   Array<Stmt> root_stmts;
11 changes: 10 additions & 1 deletion src/tir/ir/data_type_rewriter.cc
@@ -511,6 +511,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) {
 
 IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type)
     : target_data_type_(std::move(target_data_type)) {}
+
 PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
   Map<Var, Buffer> new_buffer_map = func->buffer_map;
   for (const auto& [var, buffer] : func->buffer_map) {
@@ -534,13 +535,21 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
   if (auto it = var_remap_.find(GetRef<Var>(op)); it != var_remap_.end()) {
     return (*it).second;
   }
-  if (is_enabled_) {
+  if (is_enabled_ && op->dtype != target_data_type_) {
     Var new_var = GetRef<Var>(op).copy_with_dtype(target_data_type_);
     var_remap_.Set(GetRef<Var>(op), new_var);
     return std::move(new_var);
   }
   return GetRef<PrimExpr>(op);
 }
+
+PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) {
+  if (is_enabled_) {
+    PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value);
+    return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value);
+  }
+  return IndexDataTypeRewriter::VisitExpr_(op);
+}
 
 }  // namespace tir
 }  // namespace tvm
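
An illustrative Python sketch of the two rules the hunks above implement, for readers who prefer not to trace the C++ visitor (illustrative only, not a TVM API: the real normalizer is this C++ class, used internally by te.create_prim_func, and it also caches Var replacements in var_remap_):

from tvm import tir

TARGET = "int64"  # plays the role of target_data_type_

def rewrite_expr(expr):
    # Stand-in for the recursive VisitExpr dispatch; only the two changed rules are shown.
    if isinstance(expr, tir.Var):
        return rewrite_var(expr)
    if isinstance(expr, tir.Cast):
        return rewrite_cast(expr)
    return expr

def rewrite_var(var):
    # Only materialize a replacement Var when its dtype differs from the target.
    return tir.Var(var.name, TARGET) if var.dtype != TARGET else var

def rewrite_cast(cast):
    # CastNode rule: rewrite the inner value first, then cast only if still needed.
    value = rewrite_expr(cast.value)
    return value if value.dtype == TARGET else tir.Cast(TARGET, value)

x = tir.Var("x", "int32")
print(rewrite_expr(tir.Cast("int64", x)))  # the cast folds away once x is normalized to int64
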
36 changes: 34 additions & 2 deletions tests/python/unittest/test_te_create_primfunc.py
@@ -46,8 +46,6 @@ def test_unique_name_reduction_block():
 
 def _check_workload(te_workload, tir_workload, index_dtype_override=None):
     func = te.create_prim_func(te_workload(), index_dtype_override)
-    print(func.script())
-    print(tvm.ir.base.get_first_structural_mismatch(func, tir_workload))
     tvm.ir.assert_structural_equal(func, tir_workload)
     # make sure that we can create schedule from the func
     s = tir.Schedule(func, debug_mask="all")
@@ -575,6 +573,39 @@ def expected(
     _check_workload(te_func, expected)
 
 
+def te_reshape():
+    # The following is possible to be generated by TOPI. So we test this case.
+    A = te.placeholder((tvm.tir.IntImm("int64", 2), tvm.tir.IntImm("int64", 4)), name="A")
+    B = topi.reshape(A, (4, 2))
+    return [A, B]
+
+
+@T.prim_func
+def tir_reshape(
+    A: T.Buffer[(T.int64(2), T.int64(4)), "float32"],
+    T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"],
+):
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    for i0, i1 in T.grid(T.int64(4), T.int64(2)):
+        with T.block("T_reshape"):
+            ax0, ax1 = T.axis.remap("SS", [i0, i1])
+            T.reads(
+                A[
+                    (ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4),
+                    (ax0 * T.int64(2) + ax1) % T.int64(4),
+                ]
+            )
+            T.writes(T_reshape[ax0, ax1])
+            T_reshape[ax0, ax1] = A[
+                (ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4),
+                (ax0 * T.int64(2) + ax1) % T.int64(4),
+            ]
+
+
+def test_reshape():
+    _check_workload(te_reshape, tir_reshape, index_dtype_override="int64")
+
+
 if __name__ == "__main__":
     test_unique_name_complete_block()
     test_unique_name_reduction_block()
@@ -593,3 +624,4 @@ def expected(
     test_argmax_val_idx()
     test_int64_indices()
     test_zero_dim_add()
+    test_reshape()
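
For a quick manual check that the normalized function really contains no casts for this workload (a sketch mirroring _check_workload above; assumes te_reshape from this test file is in scope):

import tvm
from tvm import te, tir

func = te.create_prim_func(te_reshape(), index_dtype_override="int64")

casts = []
tvm.tir.stmt_functor.post_order_visit(
    func.body, lambda node: casts.append(node) if isinstance(node, tir.Cast) else None
)
print(len(casts))  # expected to be 0: the reshape indices stay int64 with no redundant casts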