From bf9ddb36a9a8d4bad397c68ff9abf1dc6ad343d6 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 15 Jan 2023 14:41:43 -0500 Subject: [PATCH] [Fix] IndexDataTypeNormalizer not unwrapping float casting --- src/tir/ir/data_type_rewriter.cc | 5 +- .../unittest/test_te_create_primfunc.py | 66 +++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index f0f0d84644fe4..8da7cfdd5b976 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -574,7 +574,10 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) { } PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) { - if (is_enabled_) { + // Unwrap the cast only when the dtype of this cast is integer dtype. + // When the dtype of this cast is not integer dtype, it means that this cast + // has some other purpose, and we should not unwrap the cast. + if (is_enabled_ && op->dtype.is_int()) { PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value); return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value); } diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index f78dc458d9d3e..4b8d857e86192 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -689,6 +689,72 @@ def test_argmax(): tvm.ir.assert_structural_equal(prim_func, argmax_expected) +def te_resize2d_symbolic(): + oh = tir.Var("oh", "int64") + ow = tir.Var("ow", "int64") + roi = (0.0, 0.0, 0.0, 0.0) + A = te.placeholder((2, 3, 128, 128), "float32", name="A") + B = topi.image.resize2d( + A, + roi, + size=(oh, ow), + method="nearest_neighbor", + coordinate_transformation_mode="asymmetric", + rounding_method="round", + ) + return [A, B] + + +@T.prim_func +def tir_resize2d_symbolic( + A: T.Buffer[(T.int64(2), T.int64(3), T.int64(128), T.int64(128)), "float32"], + var_resize: T.handle, +): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + oh = T.var("int64") + ow = T.var("int64") + resize = T.match_buffer(var_resize, [T.int64(2), T.int64(3), oh, ow], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), oh, ow): + with T.block("resize"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[v_i0, v_i1, T.int64(0) : T.int64(128), T.int64(0) : T.int64(128)]) + T.writes(resize[v_i0, v_i1, v_i2, v_i3]) + resize[v_i0, v_i1, v_i2, v_i3] = A[ + v_i0, + v_i1, + T.max( + T.min( + T.Cast( + "int64", + T.round( + T.float32(128) / T.Cast("float32", oh) * T.Cast("float32", v_i2), + dtype="float32", + ), + ), + T.int64(127), + ), + T.int64(0), + ), + T.max( + T.min( + T.Cast( + "int64", + T.round( + T.float32(128) / T.Cast("float32", ow) * T.Cast("float32", v_i3), + dtype="float32", + ), + ), + T.int64(127), + ), + T.int64(0), + ), + ] + + +def test_resize2d_symbolic(): + _check_workload(te_resize2d_symbolic, tir_resize2d_symbolic, index_dtype_override="int64") + + def test_extern_with_explicit_buffer_access(): def te_extern(): A = te.placeholder((128, 128), name="A")