diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 7d601d9a8baec..a75d316a7ece4 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -865,6 +865,7 @@ IntSet UnionLowerBound(const Array& sets) { PrimExpr min_inclusive{nullptr}; PrimExpr max_inclusive(nullptr); for (const IntSet& int_set : sets) { + if (int_set.IsNothing()) continue; if (const auto* interval_set = int_set.as()) { PrimExpr new_min_inclusive = interval_set->min_value; PrimExpr new_max_inclusive = interval_set->max_value; diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index da3fd94f8192b..a6c7390782a9f 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -370,7 +370,8 @@ def test_union_lower_bound(): pos_inf = tvm.arith.int_set.pos_inf() set_0 = tvm.arith.IntervalSet(min_value=neg_inf, max_value=0) set_1 = tvm.arith.IntervalSet(min_value=1, max_value=pos_inf) - result = tvm.arith.int_set.union_lower_bound([set_0, set_1]) + set_2 = tvm.arith.IntervalSet(min_value=pos_inf, max_value=neg_inf) + result = tvm.arith.int_set.union_lower_bound([set_0, set_1, set_2]) assert result.min_value.same_as(neg_inf) assert result.max_value.same_as(pos_inf)