diff --git a/include/tvm/arith/util.h b/include/tvm/arith/util.h deleted file mode 100644 index adfcefcd2e21..000000000000 --- a/include/tvm/arith/util.h +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/arith/util.h - * \brief Utils for arithmetic analysis. - */ -#ifndef TVM_ARITH_UTIL_H_ -#define TVM_ARITH_UTIL_H_ - -#include -#include - -namespace tvm { -/*! \brief namespace of arithmetic analysis. */ -namespace arith { - -/*! - * \brief Calculate the extended greatest common divisor for two values. - * See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm. - * \param a an integer number - * \param b an integer number - * \return 3 integers (div, m, n) where div = gcd(a, b) and a*m + b*n = div - */ -std::tuple xgcd(int64_t a, int64_t b); - -} // namespace arith -} // namespace tvm -#endif // TVM_ARITH_UTIL_H_ diff --git a/src/arith/int_operator.h b/src/arith/int_operator.h index 8e4dda0284e2..b69ce4fe5858 100644 --- a/src/arith/int_operator.h +++ b/src/arith/int_operator.h @@ -25,6 +25,7 @@ #define TVM_ARITH_INT_OPERATOR_H_ #include +#include namespace tvm { namespace arith { @@ -117,6 +118,70 @@ inline int64_t floormod(int64_t x, int64_t y) { return is_floor_div ? rmod : rmod + y; } +/*! + * \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b) + * \param a The first coefficient. + * \param b The second coefficient. + * \param x The solution of x. + * \param y The solution of y. + * \return The GCD of a and b. + */ +inline int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) { + // Extended Euclidean algorithm + // if a < 0, the problem can be convert into + // |a|* (-x) + b * y = gcd(|a|, b) + // + // initial condition: + // a * 0 + b * 1 = b + // a * 1 + b * 0 = a + int64_t s = 0, old_s = 1; + int64_t r = b, old_r = a >= 0 ? a : -a; + // Iteration (r2 < r1): + // a * x1 + b * y1 = r1 + // a * x2 + b * y2 = r2 + // The above two eqs can derive the following eq (q = r1 / r2) + // a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3 + // Because r3 < r2, the iteration can eventually terminate + while (r != 0) { + int64_t q = old_r / r; + int64_t tmp = old_r; + old_r = r; + r = tmp - q * r; + tmp = old_s; + old_s = s; + s = tmp - q * s; + } + + *x = a >= 0 ? old_s : -old_s; + if (b != 0) { + *y = (old_r - (*x) * a) / b; + } else { + *y = 1; + } + + return old_r; +} + +/*! + * \brief Take GCD of a and b. + * \param a The first operand. + * \param b The second operand. + * \return The result. + */ +inline int64_t ZeroAwareGCD(int64_t a, int64_t b) { + if (a < 0) a = -a; + if (b < 0) b = -b; + if (a < b) std::swap(a, b); + if (b == 0) return a; + // perform GCD (greatest common divisor) + // ax + by = gcd(a, b) z if a != 0, b != 0 + while (a % b != 0) { + a = a % b; + std::swap(a, b); + } + return b; +} + } // namespace arith } // namespace tvm #endif // TVM_ARITH_INT_OPERATOR_H_ diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 7ddb8f5251e7..2645fe910024 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -270,49 +270,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor= 0 ? a : -a; - // Iteration (r2 < r1): - // a * x1 + b * y1 = r1 - // a * x2 + b * y2 = r2 - // The above two eqs can derive the following eq (q = r1 / r2) - // a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3 - // Because r3 < r2, the iteration can eventually terminate - while (r != 0) { - int64_t q = old_r / r; - int64_t tmp = old_r; - old_r = r; - r = tmp - q * r; - tmp = old_s; - old_s = s; - s = tmp - q * s; - } - - *x = a >= 0 ? old_s : -old_s; - if (b != 0) { - *y = (old_r - (*x) * a) / b; - } else { - *y = 1; - } - return old_r; - } /*! * \brief Create interect of two sets. * \param a The left operand. @@ -339,25 +297,6 @@ class ModularSetAnalyzer::Impl : public ExprFunctor #include #include -#include #include #include #include #include #include +#include "int_operator.h" + namespace tvm { namespace arith { @@ -96,7 +97,7 @@ void SmithNormalFormDiag(std::vector>* S, std::vector>* S, std::vector -#include - -namespace tvm { -namespace arith { - -std::tuple xgcd(int64_t a, int64_t b) { - int64_t s = 0, old_s = 1; - int64_t t = 1, old_t = 0; - int64_t r = b, old_r = a; - - while (r != 0) { - int64_t q = old_r / r; - std::swap(r, old_r); - r -= q * old_r; - std::swap(s, old_s); - s -= q * old_s; - std::swap(t, old_t); - t -= q * old_t; - } - - CHECK_EQ(a % old_r, 0); - CHECK_EQ(b % old_r, 0); - CHECK(old_r == old_s * a + old_t * b); - - return std::make_tuple(old_r, old_s, old_t); -} - -} // namespace arith -} // namespace tvm