diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 96cb92850d5a..cd482279efe0 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -515,7 +515,6 @@ class IterMapRewriter : public ExprMutator { */ Optional TryFuseIters(IterSumExpr expr) { if (!is_zero(expr->base)) return NullOpt; - if (expr->args.size() == 1) return expr->args[0]; // select the iterators in order std::vector visited(expr->args.size(), false); std::vector flattened_iters, grouped_iters; diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index b34acb9ae359..c307034c04c9 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -285,6 +285,15 @@ def test_predicate(): ) assert len(res) == 0 + # zero iter + xo = tvm.tir.Var("xo", "int32"), 1 + xi = tvm.tir.Var("xi", "int32"), 129 + y = tvm.tir.Var("y", "int32"), 128 + + res = tvm.arith.detect_iter_map( + [xo[0] * 129 + xi[0], y[0]], var_dom([xo, xi, y]), xo[0] * 129 + xi[0] < 128 + ) + def convert_division(divisions): if divisions is None or len(divisions) == 0: