From 4147e4a0e33a51297a4f8bb8d8586535f7b10cc1 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 16 Mar 2023 02:41:47 +0000 Subject: [PATCH] [Fix][TIR] Fix tvm::arith::UnionLowerBound --- src/arith/int_set.cc | 1 + tests/python/unittest/test_arith_intset.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 7d601d9a8baec..a75d316a7ece4 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -865,6 +865,7 @@ IntSet UnionLowerBound(const Array& sets) { PrimExpr min_inclusive{nullptr}; PrimExpr max_inclusive(nullptr); for (const IntSet& int_set : sets) { + if (int_set.IsNothing()) continue; if (const auto* interval_set = int_set.as()) { PrimExpr new_min_inclusive = interval_set->min_value; PrimExpr new_max_inclusive = interval_set->max_value; diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index da3fd94f8192b..12214c596ce70 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -373,6 +373,10 @@ def test_union_lower_bound(): result = tvm.arith.int_set.union_lower_bound([set_0, set_1]) assert result.min_value.same_as(neg_inf) assert result.max_value.same_as(pos_inf) + set_2 = tvm.arith.IntervalSet(min_value=pos_inf, max_value=neg_inf) + result = tvm.arith.int_set.union_lower_bound([set_0, set_2]) + assert result.min_value.same_as(neg_inf) + assert result.max_value.same_as(0) if __name__ == "__main__":