diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py new file mode 100644 index 0000000000000..40e977e61d75b --- /dev/null +++ b/python/tvm/arith/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Integer bound analysis, simplification and pattern detection.""" + +from .int_set import IntSet, IntervalSet +from .analyzer import ModularSet, ConstIntBound, Analyzer +from .bound import deduce_bound +from .pattern import detect_linear_equation, detect_clip_bound diff --git a/python/tvm/arith/_ffi_api.py b/python/tvm/arith/_ffi_api.py new file mode 100644 index 0000000000000..c551e56515631 --- /dev/null +++ b/python/tvm/arith/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.arith""" +import tvm._ffi + + +tvm._ffi._init_api("arith", __name__) diff --git a/python/tvm/arith.py b/python/tvm/arith/analyzer.py similarity index 83% rename from python/tvm/arith.py rename to python/tvm/arith/analyzer.py index b67e99c204baa..382a7e033e753 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith/analyzer.py @@ -17,34 +17,7 @@ """Arithmetic data structure and utility""" import tvm._ffi from tvm.runtime import Object - - -class IntSet(Object): - """Represent a set of integer in one dimension.""" - def is_nothing(self): - """Whether the set represent nothing""" - return _IntSetIsNothing(self) - - def is_everything(self): - """Whether the set represent everything""" - return _IntSetIsEverything(self) - - -@tvm._ffi.register_object("arith.IntervalSet") -class IntervalSet(IntSet): - """Represent set of continuous interval [min_value, max_value] - - Parameters - ---------- - min_value : Expr - The minimum value in the interval. - - max_value : Expr - The maximum value in the interval. - """ - def __init__(self, min_value, max_value): - self.__init_handle_by_constructor__( - _make_IntervalSet, min_value, max_value) +from . import _ffi_api @tvm._ffi.register_object("arith.ModularSet") @@ -52,7 +25,7 @@ class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z """ def __init__(self, coeff, base): self.__init_handle_by_constructor__( - _make_ModularSet, coeff, base) + _ffi_api.ModularSet, coeff, base) @tvm._ffi.register_object("arith.ConstIntBound") @@ -72,7 +45,7 @@ class ConstIntBound(Object): def __init__(self, min_value, max_value): self.__init_handle_by_constructor__( - _make_ConstIntBound, min_value, max_value) + _ffi_api.ConstIntBound, min_value, max_value) class ConstraintScope: @@ -105,11 +78,12 @@ class Analyzer: be used to perform various symbolic integer analysis. """ def __init__(self): - _mod = _CreateAnalyzer() + _mod = _ffi_api.CreateAnalyzer() self._const_int_bound = _mod("const_int_bound") self._const_int_bound_update = _mod("const_int_bound_update") self._bind = _mod("bind") self._modular_set = _mod("modular_set") + self._simplify = _mod("Simplify") self._rewrite_simplify = _mod("rewrite_simplify") self._canonical_simplify = _mod("canonical_simplify") self._int_set = _mod("int_set") @@ -120,7 +94,7 @@ def const_int_bound(self, expr): Parameters ---------- - expr : tvm.Expr + expr : PrimExpr The expression. Returns @@ -135,7 +109,7 @@ def modular_set(self, expr): Parameters ---------- - expr : tvm.Expr + expr : PrimExpr The expression. Returns @@ -145,12 +119,27 @@ def modular_set(self, expr): """ return self._modular_set(expr) + def simplify(self, expr): + """Simplify expression via both rewrite and canonicalization. + + Parameters + ---------- + expr : PrimExpr + The expression. + + Returns + ------- + result : Expr + The result. + """ + return self._simplify(expr) + def rewrite_simplify(self, expr): """Simplify expression via rewriting rules. Parameters ---------- - expr : tvm.Expr + expr : PrimExpr The expression. Returns @@ -165,7 +154,7 @@ def canonical_simplify(self, expr): Parameters ---------- - expr : tvm.Expr + expr : PrimExpr The expression. Returns @@ -180,7 +169,7 @@ def int_set(self, expr, dom_map): Parameters ---------- - expr : tvm.Expr + expr : PrimExpr The expression. dom_map : Dict[Var, tvm.arith.IntSet] @@ -198,10 +187,10 @@ def bind(self, var, expr): Parameters ---------- - var : tvm.Var + var : tvm.tir.Var The variable. - expr : tvm.Expr + expr : PrimExpr The expression. """ return self._bind(var, expr) @@ -211,7 +200,7 @@ def constraint_scope(self, constraint): Parameters ---------- - constraint : tvm.Expr + constraint : PrimExpr The constraint expression. returns @@ -240,7 +229,7 @@ def update(self, var, info, override=False): Parameters ---------- - var : tvm.Var + var : tvm.tir.Var The variable. info : tvm.Object @@ -254,6 +243,3 @@ def update(self, var, info, override=False): else: raise TypeError( "Do not know how to handle type {}".format(type(info))) - - -tvm._ffi._init_api("tvm.arith") diff --git a/python/tvm/arith/bound.py b/python/tvm/arith/bound.py new file mode 100644 index 0000000000000..6f4b220a378ed --- /dev/null +++ b/python/tvm/arith/bound.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Bound deduction.""" +from . import _ffi_api + + +def deduce_bound(var, cond, hint_map, relax_map): + """Deduce the bound of the target variable in the cond. + + Parameters + ---------- + var : Var + The target variable to be deduced. + + cond : PrimExpr + The condition + + hint_map : Map[Var, IntSet] + Domain of variables used to help deduction. + + relax_map : Map[Var, IntSet] + The fomain of the variables to be relaxed + using the provided domain. + """ + return _ffi_api.DeduceBound(var, cond, hint_map, relax_map) diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py new file mode 100644 index 0000000000000..838e8e5227ca5 --- /dev/null +++ b/python/tvm/arith/int_set.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Integer set.""" +import tvm._ffi +from tvm.runtime import Object +from . import _ffi_api + + +class IntSet(Object): + """Represent a set of integer in one dimension.""" + def is_nothing(self): + """Whether the set represent nothing""" + return _ffi_api.IntSetIsNothing(self) + + def is_everything(self): + """Whether the set represent everything""" + return _ffi_api.IntSetIsEverything(self) + + @staticmethod + def vector(vec): + """Construct an integer set that covers the vector expr + + Parameters + ---------- + vec : PrimExpr + The vector expression. + + Returns + ------- + rset : IntSet + The result set. + """ + return _ffi_api.intset_vector(vec) + + @staticmethod + def single_point(point): + """Construct a point set. + + Parameters + ---------- + point : PrimExpr + The vector expression. + + Returns + ------- + rset : IntSet + The result set. + """ + return _ffi_api.intset_single_point(point) + + +@tvm._ffi.register_object("arith.IntervalSet") +class IntervalSet(IntSet): + """Represent set of continuous interval [min_value, max_value] + + Parameters + ---------- + min_value : PrimExpr + The minimum value in the interval. + + max_value : PrimExpr + The maximum value in the interval. + """ + def __init__(self, min_value, max_value): + self.__init_handle_by_constructor__( + _ffi_api.IntervalSet, min_value, max_value) diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py new file mode 100644 index 0000000000000..22810882701ef --- /dev/null +++ b/python/tvm/arith/pattern.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Detect common patterns.""" +from . import _ffi_api + + +def detect_linear_equation(expr, var_list): + """Match `expr = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]` + + Where coeff[i] and base are invariant of var[j] for all i and j. + + Parameters + ---------- + expr : PrimExpr + The expression to be matched. + + var_list : List[tvm.tir.Var] + A list of variables. + + Returns + ------- + coeff : List[PrimExpr] + A list of co-efficients if the match is successful. + An empty list if the match failed. + """ + return _ffi_api.DetectLinearEquation(expr, var_list) + + +def detect_clip_bound(expr, var_list): + """ Detect if expression corresponds to clip bound of the vars + + Parameters + ---------- + expr : PrimExpr + The expression to be matched. + + var_list : List[tvm.tir.Var] + A list of variables. + + Returns + ------- + coeff : List[PrimExpr] + `concat([min_value[i], max_value[i]] for i, v in enumerate(var_list))` + An empty list if the match failed. + """ + return _ffi_api.DetectClipBound(expr, var_list) diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index f996bdbfcbbe8..3942f6ef0f202 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -64,33 +64,33 @@ TVM_REGISTER_GLOBAL("arith.DeduceBound") TVM_REGISTER_GLOBAL("arith.DomainTouched") .set_body_typed(DomainTouched); -TVM_REGISTER_GLOBAL("arith._IntervalSetGetMin") +TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin") .set_body_method(&IntSet::min); -TVM_REGISTER_GLOBAL("arith._IntervalSetGetMax") +TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax") .set_body_method(&IntSet::max); -TVM_REGISTER_GLOBAL("arith._IntSetIsNothing") +TVM_REGISTER_GLOBAL("arith.IntSetIsNothing") .set_body_method(&IntSet::is_nothing); -TVM_REGISTER_GLOBAL("arith._IntSetIsEverything") +TVM_REGISTER_GLOBAL("arith.IntSetIsEverything") .set_body_method(&IntSet::is_everything); ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { return ConstIntBound(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith._make_ConstIntBound") +TVM_REGISTER_GLOBAL("arith.ConstIntBound") .set_body_typed(MakeConstIntBound); ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); } -TVM_REGISTER_GLOBAL("arith._make_ModularSet") +TVM_REGISTER_GLOBAL("arith.ModularSet") .set_body_typed(MakeModularSet); -TVM_REGISTER_GLOBAL("arith._CreateAnalyzer") +TVM_REGISTER_GLOBAL("arith.CreateAnalyzer") .set_body([](TVMArgs args, TVMRetValue* ret) { using runtime::PackedFunc; using runtime::TypedPackedFunc; @@ -108,6 +108,10 @@ TVM_REGISTER_GLOBAL("arith._CreateAnalyzer") return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { self->const_int_bound.Update(args[0], args[1], args[2]); }); + } else if (name == "Simplify") { + return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { + *ret = self->Simplify(args[0]); + }); } else if (name == "rewrite_simplify") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { *ret = self->rewrite_simplify(args[0]); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 728cca1b5705e..adb38799fdf2c 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -54,7 +54,7 @@ IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { return IntervalSet(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith._make_IntervalSet") +TVM_REGISTER_GLOBAL("arith.IntervalSet") .set_body_typed(MakeIntervalSet); diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 787dfe80d536d..5e08635cd53fd 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -38,90 +38,90 @@ def test_deduce(): fdiv = tvm.floordiv e0 = (-b)*a+c-d - res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) + res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = fdiv(d - c, b*-1) assert_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs - res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) + res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) assert_expr_equal(res0.max_value, ans0) e0 = d*a+c-d - res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) + res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = fdiv(d-c, d) assert_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs - res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) + res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) assert_expr_equal(res0.max_value, ans0) e1 = (a*4+b < c) - res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) + res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) ans1 = fdiv(c-1-b, 4) assert_expr_equal(res1.max_value, ans1) # expression containing variable a is on rhs e1 = (c > a*4+b) - res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) + res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) assert_expr_equal(res1.max_value, ans1) e2 = (tvm.max(5, a * 4) < 0) - res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) + res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) assert str(res2.max_value) == "neg_inf" assert str(res2.min_value) == "pos_inf" # expression containing variable a is on rhs e2 = (zero < tvm.max(5, a * 4)) - res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) + res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) assert str(res2.max_value) == "neg_inf" assert str(res2.min_value) == "pos_inf" e3 = (-b)+a*c-d - res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) + res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) ans3 = fdiv(2,c)+1 assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) - res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) + res3 = tvm.arith.deduce_bound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) # tests for `EQ` op - res4 = tvm.arith.DeduceBound(a, a == b, {}, {}) + res4 = tvm.arith.deduce_bound(a, a == b, {}, {}) assert_expr_equal(res4.max_value, b) assert_expr_equal(res4.min_value, b) # Unsatisfiable `EQ`, variable as one of the Operand - res5 = tvm.arith.DeduceBound(a, (a == b), {b: b_s}, {b: b_s}) + res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s}) assert str(res5.max_value) == "neg_inf" assert str(res5.min_value) == "pos_inf" # variable `a` on the RHS side - res6 = tvm.arith.DeduceBound(a, 10 == a, {}, {}) + res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {}) assert_expr_equal(res6.max_value, 10) assert_expr_equal(res6.min_value, 10) # Add, Sub in `EQ` e4 = ((a - c) == (b + d)) ans4 = (b + d + c) - res7 = tvm.arith.DeduceBound(a, e4, {b: b_s, c: c_s, d: d_s}, {}) + res7 = tvm.arith.deduce_bound(a, e4, {b: b_s, c: c_s, d: d_s}, {}) assert_expr_equal(res7.max_value, ans4) assert_expr_equal(res7.min_value, ans4) # Satisfiable Mul in `EQ` with negative sign - res8 = tvm.arith.DeduceBound(a, (5 * a == -10), {}, {}) + res8 = tvm.arith.deduce_bound(a, (5 * a == -10), {}, {}) assert_expr_equal(res8.max_value, -2) assert_expr_equal(res8.min_value, -2) # Unsatisfiable Mul in `EQ` e5 = (4 * a == b) - res9 = tvm.arith.DeduceBound(a, e5, {b: b_s}, {}) + res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {}) assert str(res9.max_value) == "neg_inf" assert str(res9.min_value) == "pos_inf" # Unsatisfiable Mul in `EQ` - res10 = tvm.arith.DeduceBound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0) + res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0) assert str(res10.max_value) == "neg_inf" assert str(res10.min_value) == "pos_inf" @@ -137,15 +137,15 @@ def test_check(): d_s = tvm.arith.IntervalSet(-3, -1) # no compare operator - res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {}) + res1 = tvm.arith.deduce_bound(a, a+b, {b: b_s}, {}) assert res1.is_nothing() # multiple compare operators - res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {}) + res2 = tvm.arith.deduce_bound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {}) assert res2.is_nothing() # multiple target variable - res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {}) + res2 = tvm.arith.deduce_bound(a, a*2-a>b, {b: b_s}, {}) assert res2.is_nothing() def test_deduce_basic(): @@ -155,21 +155,21 @@ def test_basic(a1, a2, coff): b_s = tvm.arith.IntervalSet(a1, a2) e0 = b + a*coff + 3 - res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, e0<17, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1 # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1 # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 - res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, e0>=17, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1 @@ -187,21 +187,21 @@ def test_complex(a1, a2, coff): b_s = tvm.arith.IntervalSet(a1, a2) e0 = (b*3 + a* coff) * 4 - res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, e0<63, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1 # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1 - res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, e0>63, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1 # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1 diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/tests/python/unittest/test_arith_detect_clip_bound.py index 3301c24049aef..44ae24cb6815c 100644 --- a/tests/python/unittest/test_arith_detect_clip_bound.py +++ b/tests/python/unittest/test_arith_detect_clip_bound.py @@ -20,14 +20,14 @@ def test_basic(): a = tvm.var("a") b = tvm.var("b") c = tvm.var("c") - m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6, + m = tvm.arith.detect_clip_bound(tvm.all(a * 1 < b * 6, a - 1 > 0), [a]) assert tvm.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0 assert m[0].value == 2 - m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6, + m = tvm.arith.detect_clip_bound(tvm.all(a * 1 < b * 6, a - 1 > 0), [a, b]) assert len(m) == 0 - m = tvm.arith.DetectClipBound(tvm.all(a + 10 * c <= 20, + m = tvm.arith.detect_clip_bound(tvm.all(a + 10 * c <= 20, b - 1 > 0), [a, b]) assert tvm.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0 assert tvm.ir_pass.Simplify(m[2] - 2).value == 0 diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index cacb62456b794..3b103026aec3d 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -19,50 +19,50 @@ def test_basic(): a = tvm.var("a") b = tvm.var("b") - m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, [a]) + m = tvm.arith.detect_linear_equation(a * 4 + b * 6 + 7, [a]) assert m[0].value == 4 assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0 - m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, [a]) + m = tvm.arith.detect_linear_equation(a * 4 * (a+1) + b * 6 + 7, [a]) assert len(m) == 0 - m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, [a]) + m = tvm.arith.detect_linear_equation(a * 4 + (a+1) + b * 6 + 7, [a]) assert m[0].value == 5 assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0 - m = tvm.arith.DetectLinearEquation(a * b + 7, [a]) + m = tvm.arith.detect_linear_equation(a * b + 7, [a]) assert m[0] == b - m = tvm.arith.DetectLinearEquation(b * 7, [a]) + m = tvm.arith.detect_linear_equation(b * 7, [a]) assert m[0].value == 0 - m = tvm.arith.DetectLinearEquation(b * 7, []) + m = tvm.arith.detect_linear_equation(b * 7, []) assert len(m) == 1 assert tvm.ir_pass.Simplify(m[0] - b * 7).value == 0 def test_multivariate(): v = [tvm.var("v%d" % i) for i in range(4)] b = tvm.var("b") - m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8, v) + m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v) assert(tvm.ir_pass.Equal(tvm.ir_pass.Simplify(m[0]), b + 5)) assert(m[1].value == 8) - m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v) + m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v) assert(len(m) == 0) - m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v) + m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v) assert(len(m) == 0) - m = tvm.arith.DetectLinearEquation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v) + m = tvm.arith.detect_linear_equation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v) assert(m[1].value == 16) assert(m[2].value == 2) assert(m[len(m)-1].value == 2) - m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [v[2]]) + m = tvm.arith.detect_linear_equation((v[0] - v[1]), [v[2]]) assert(m[0].value == 0) assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) - m = tvm.arith.DetectLinearEquation((v[0] - v[1]), []) + m = tvm.arith.detect_linear_equation((v[0] - v[1]), []) assert(len(m) == 1) assert(tvm.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0) diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index 3e45d4e5fd93c..7876fb6c4d37a 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -35,19 +35,19 @@ def test_domain_touched(): ) ) ) - a_domain_r = tvm.arith.DomainTouched(ir, a, True, False) + a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False) assert a_domain_r[0].min.value == -1 assert a_domain_r[0].extent.value == 100 assert a_domain_r[1].min.value == -1 assert a_domain_r[1].extent.name == 'm' - a_domain_w = tvm.arith.DomainTouched(ir, a, False, True) + a_domain_w = tvm.arith._ffi_api.DomainTouched(ir, a, False, True) assert a_domain_w[0].min.value == 0 assert a_domain_w[0].extent.value == 100 assert a_domain_w[1].min.value == 0 assert a_domain_w[1].extent.name == 'm' - a_domain_rw= tvm.arith.DomainTouched(ir, a, True, True) + a_domain_rw= tvm.arith._ffi_api.DomainTouched(ir, a, True, True) assert a_domain_rw[0].min.value == -1 assert a_domain_rw[0].extent.value == 101 assert a_domain_rw[1].min.value == -1 @@ -55,17 +55,16 @@ def test_domain_touched(): assert a_domain_rw[1].extent.a.name == 'm' assert a_domain_rw[1].extent.b.value == 1 - b_domain_r = tvm.arith.DomainTouched(ir, b, True, False) + b_domain_r = tvm.arith._ffi_api.DomainTouched(ir, b, True, False) assert b_domain_r assert b_domain_r[0].min.value == -1 assert b_domain_r[0].extent.value == 100 assert b_domain_r[1].min.value == 1 assert b_domain_r[1].extent.name == 'm' - b_domain_w = tvm.arith.DomainTouched(ir, b, False, True) + b_domain_w = tvm.arith._ffi_api.DomainTouched(ir, b, False, True) assert isinstance(b_domain_w, tvm.container.Array) assert len(b_domain_w) == 0 if __name__ == "__main__": test_domain_touched() - diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index d83d33db5c1bb..dad2fa705b0f6 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -36,12 +36,16 @@ def test_basic(): assert s.min_value.value == 2 assert s.max_value.value == 3 + s = tvm.arith.IntSet.single_point(2) + assert s.min_value.value == 2 + assert s.max_value.value == 2 + def test_vector(): base = 10 stride = 3 lanes = 2 - s = tvm.arith.intset_vector(tvm.tir.Ramp(base, stride, lanes)) + s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, stride, lanes)) assert s.min_value.value == base assert s.max_value.value == base + stride * lanes - 1 diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index 8b8a2f06b4982..36d8e4198a40b 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -76,7 +76,7 @@ def _post_order(op): args = [] args += op.args[:base_args] for i in range(3): - m = tvm.arith.DetectLinearEquation( + m = tvm.arith.detect_linear_equation( op.args[i + base_args], [loop_var]) if not m: fail[0] = True @@ -867,25 +867,25 @@ def _flatten_loop(src_coeff, dst_coeff, extents): type(loop_body.value), str(loop_body.value), str(stmt))) # Derive array index coefficients - dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices) + dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices) # Check if lhs/rhs is immediate use_imm = False imm_val = None if isinstance(rhs, tvm.tir.IntImm): assert lhs.buffer_var.same_as(dst_var) - src_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices) + src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) use_imm = True imm_val = rhs if isinstance(lhs, tvm.tir.IntImm): assert rhs.buffer_var.same_as(dst_var) - src_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices) + src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) use_imm = True imm_val = lhs if imm_val is None: imm_val = 0 assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var) - src_lhs_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices) - src_rhs_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices) + src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) + src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) # Determine which side has the same coefficients lhs_equal = True rhs_equal = True