Skip to content

Commit

Permalink
[Fix] Fix IndexDataTypeNormalizer to avoid redundant casting
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Nov 21, 2022
1 parent 26d9b5a commit 48044ce
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 4 deletions.
1 change: 1 addition & 0 deletions include/tvm/tir/data_type_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter {
using Parent::VisitStmt_;
PrimExpr VisitExpr_(const IntImmNode* op) final;
PrimExpr VisitExpr_(const VarNode* op) final;
PrimExpr VisitExpr_(const CastNode* op) final;

DataType target_data_type_ = DataType::Int(64);
};
Expand Down
2 changes: 1 addition & 1 deletion src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
const Array<runtime::NDArray>& constants,
std::optional<DataType> index_dtype_override) {
// Infomations used in CreatePrimFunc and its sub-functions.
// Informations used in CreatePrimFunc and its sub-functions.
CreateFuncInfo info(arg_list);
// Root body stmts.
Array<Stmt> root_stmts;
Expand Down
11 changes: 10 additions & 1 deletion src/tir/ir/data_type_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) {

IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type)
: target_data_type_(std::move(target_data_type)) {}

PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
Map<Var, Buffer> new_buffer_map = func->buffer_map;
for (const auto& [var, buffer] : func->buffer_map) {
Expand All @@ -534,13 +535,21 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
if (auto it = var_remap_.find(GetRef<Var>(op)); it != var_remap_.end()) {
return (*it).second;
}
if (is_enabled_) {
if (is_enabled_ && op->dtype != target_data_type_) {
Var new_var = GetRef<Var>(op).copy_with_dtype(target_data_type_);
var_remap_.Set(GetRef<Var>(op), new_var);
return std::move(new_var);
}
return GetRef<PrimExpr>(op);
}

PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) {
if (is_enabled_) {
PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value);
return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value);
}
return IndexDataTypeRewriter::VisitExpr_(op);
}

} // namespace tir
} // namespace tvm
36 changes: 34 additions & 2 deletions tests/python/unittest/test_te_create_primfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ def test_unique_name_reduction_block():

def _check_workload(te_workload, tir_workload, index_dtype_override=None):
func = te.create_prim_func(te_workload(), index_dtype_override)
print(func.script())
print(tvm.ir.base.get_first_structural_mismatch(func, tir_workload))
tvm.ir.assert_structural_equal(func, tir_workload)
# make sure that we can create schedule from the func
s = tir.Schedule(func, debug_mask="all")
Expand Down Expand Up @@ -575,6 +573,39 @@ def expected(
_check_workload(te_func, expected)


def te_reshape():
# The following is possible to be generated by TOPI. So we test this case.
A = te.placeholder((tvm.tir.IntImm("int64", 2), tvm.tir.IntImm("int64", 4)), name="A")
B = topi.reshape(A, (4, 2))
return [A, B]


@T.prim_func
def tir_reshape(
A: T.Buffer[(T.int64(2), T.int64(4)), "float32"],
T_reshape: T.Buffer[(T.int64(4), T.int64(2)), "float32"],
):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i0, i1 in T.grid(T.int64(4), T.int64(2)):
with T.block("T_reshape"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(
A[
(ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4),
(ax0 * T.int64(2) + ax1) % T.int64(4),
]
)
T.writes(T_reshape[ax0, ax1])
T_reshape[ax0, ax1] = A[
(ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4),
(ax0 * T.int64(2) + ax1) % T.int64(4),
]


def test_reshape():
_check_workload(te_reshape, tir_reshape, index_dtype_override="int64")


if __name__ == "__main__":
test_unique_name_complete_block()
test_unique_name_reduction_block()
Expand All @@ -593,3 +624,4 @@ def expected(
test_argmax_val_idx()
test_int64_indices()
test_zero_dim_add()
test_reshape()

0 comments on commit 48044ce

Please sign in to comment.