Skip to content

Commit

Permalink
[Fix] IndexDataTypeNormalizer not unwrapping float casting
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Jan 15, 2023
1 parent 84a9f8c commit bf9ddb3
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/tir/ir/data_type_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,10 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
}

PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) {
if (is_enabled_) {
// Unwrap the cast only when the dtype of this cast is integer dtype.
// When the dtype of this cast is not integer dtype, it means that this cast
// has some other purpose, and we should not unwrap the cast.
if (is_enabled_ && op->dtype.is_int()) {
PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value);
return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value);
}
Expand Down
66 changes: 66 additions & 0 deletions tests/python/unittest/test_te_create_primfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,72 @@ def test_argmax():
tvm.ir.assert_structural_equal(prim_func, argmax_expected)


def te_resize2d_symbolic():
oh = tir.Var("oh", "int64")
ow = tir.Var("ow", "int64")
roi = (0.0, 0.0, 0.0, 0.0)
A = te.placeholder((2, 3, 128, 128), "float32", name="A")
B = topi.image.resize2d(
A,
roi,
size=(oh, ow),
method="nearest_neighbor",
coordinate_transformation_mode="asymmetric",
rounding_method="round",
)
return [A, B]


@T.prim_func
def tir_resize2d_symbolic(
A: T.Buffer[(T.int64(2), T.int64(3), T.int64(128), T.int64(128)), "float32"],
var_resize: T.handle,
):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
oh = T.var("int64")
ow = T.var("int64")
resize = T.match_buffer(var_resize, [T.int64(2), T.int64(3), oh, ow], dtype="float32")
for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), oh, ow):
with T.block("resize"):
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(A[v_i0, v_i1, T.int64(0) : T.int64(128), T.int64(0) : T.int64(128)])
T.writes(resize[v_i0, v_i1, v_i2, v_i3])
resize[v_i0, v_i1, v_i2, v_i3] = A[
v_i0,
v_i1,
T.max(
T.min(
T.Cast(
"int64",
T.round(
T.float32(128) / T.Cast("float32", oh) * T.Cast("float32", v_i2),
dtype="float32",
),
),
T.int64(127),
),
T.int64(0),
),
T.max(
T.min(
T.Cast(
"int64",
T.round(
T.float32(128) / T.Cast("float32", ow) * T.Cast("float32", v_i3),
dtype="float32",
),
),
T.int64(127),
),
T.int64(0),
),
]


def test_resize2d_symbolic():
_check_workload(te_resize2d_symbolic, tir_resize2d_symbolic, index_dtype_override="int64")


def test_extern_with_explicit_buffer_access():
def te_extern():
A = te.placeholder((128, 128), name="A")
Expand Down

0 comments on commit bf9ddb3

Please sign in to comment.