Skip to content

Commit

Permalink
[Arith] ExtendedEuclidean merge impl to int_operator (apache#5625)
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored and trevor-m committed Jun 18, 2020
1 parent e331f5a commit 8b53c9b
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 162 deletions.
45 changes: 0 additions & 45 deletions include/tvm/arith/util.h

This file was deleted.

65 changes: 65 additions & 0 deletions src/arith/int_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_ARITH_INT_OPERATOR_H_

#include <limits>
#include <utility>

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -117,6 +118,70 @@ inline int64_t floormod(int64_t x, int64_t y) {
return is_floor_div ? rmod : rmod + y;
}

/*!
* \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
* \param a The first coefficient.
* \param b The second coefficient.
* \param x The solution of x.
* \param y The solution of y.
* \return The GCD of a and b.
*/
inline int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
// Extended Euclidean algorithm
// if a < 0, the problem can be convert into
// |a|* (-x) + b * y = gcd(|a|, b)
//
// initial condition:
// a * 0 + b * 1 = b
// a * 1 + b * 0 = a
int64_t s = 0, old_s = 1;
int64_t r = b, old_r = a >= 0 ? a : -a;
// Iteration (r2 < r1):
// a * x1 + b * y1 = r1
// a * x2 + b * y2 = r2
// The above two eqs can derive the following eq (q = r1 / r2)
// a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3
// Because r3 < r2, the iteration can eventually terminate
while (r != 0) {
int64_t q = old_r / r;
int64_t tmp = old_r;
old_r = r;
r = tmp - q * r;
tmp = old_s;
old_s = s;
s = tmp - q * s;
}

*x = a >= 0 ? old_s : -old_s;
if (b != 0) {
*y = (old_r - (*x) * a) / b;
} else {
*y = 1;
}

return old_r;
}

/*!
* \brief Take GCD of a and b.
* \param a The first operand.
* \param b The second operand.
* \return The result.
*/
inline int64_t ZeroAwareGCD(int64_t a, int64_t b) {
if (a < 0) a = -a;
if (b < 0) b = -b;
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD (greatest common divisor)
// ax + by = gcd(a, b) z if a != 0, b != 0
while (a % b != 0) {
a = a % b;
std::swap(a, b);
}
return b;
}

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_OPERATOR_H_
61 changes: 0 additions & 61 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,49 +270,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0);
}
}
/*!
* \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
* \param a The first coefficient.
* \param b The second coefficient.
* \param x The solution of x.
* \param y The solution of y.
* \return The GCD of a and b.
*/
static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
// Extended Euclidean algorithm
// if a < 0, the problem can be convert into
// |a|* (-x) + b * y = gcd(|a|, b)
//
// initial condition:
// a * 0 + b * 1 = b
// a * 1 + b * 0 = a
int64_t s = 0, old_s = 1;
int64_t r = b, old_r = a >= 0 ? a : -a;
// Iteration (r2 < r1):
// a * x1 + b * y1 = r1
// a * x2 + b * y2 = r2
// The above two eqs can derive the following eq (q = r1 / r2)
// a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3
// Because r3 < r2, the iteration can eventually terminate
while (r != 0) {
int64_t q = old_r / r;
int64_t tmp = old_r;
old_r = r;
r = tmp - q * r;
tmp = old_s;
old_s = s;
s = tmp - q * s;
}

*x = a >= 0 ? old_s : -old_s;
if (b != 0) {
*y = (old_r - (*x) * a) / b;
} else {
*y = 1;
}

return old_r;
}
/*!
* \brief Create interect of two sets.
* \param a The left operand.
Expand All @@ -339,25 +297,6 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
return Nothing();
}
}
/*!
* \brief Take GCD of a and b.
* \param a The first operand.
* \param b The second operand.
* \return The result.
*/
static int64_t ZeroAwareGCD(int64_t a, int64_t b) {
if (a < 0) a = -a;
if (b < 0) b = -b;
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD (greatest common divisor)
// ax + by = gcd(a, b) z if a != 0, b != 0
while (a % b != 0) {
a = a % b;
std::swap(a, b);
}
return b;
}
/*!
* \brief return everything dtype can represent.
* \return Bound that represent everything dtype can represent.
Expand Down
7 changes: 4 additions & 3 deletions src/arith/solve_linear_equation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/arith/pattern.h>
#include <tvm/arith/util.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include "int_operator.h"

namespace tvm {
namespace arith {

Expand Down Expand Up @@ -96,7 +97,7 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::
int64_t g, a, b;
// g = a*matrix[index][index] + b*matrix[i][index]
if ((*S)[i][index] % (*S)[index][index] != 0) {
std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[i][index]);
g = ExtendedEuclidean((*S)[index][index], (*S)[i][index], &a, &b);
} else {
// Explicitly avoid changing the index-th row. This is important to avoid infinite loop.
g = (*S)[index][index];
Expand Down Expand Up @@ -149,7 +150,7 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::
int64_t g, a, b;
// g = a*matrix[index][index] + b*matrix[index][j]
if ((*S)[index][j] % (*S)[index][index] != 0) {
std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[index][j]);
g = ExtendedEuclidean((*S)[index][index], (*S)[index][j], &a, &b);
// During this phase we may disrupt the zeroness of the index-th column, so we will
// have to take some action if this might have happened.
changed = true;
Expand Down
53 changes: 0 additions & 53 deletions src/arith/util.cc

This file was deleted.

0 comments on commit 8b53c9b

Please sign in to comment.