Skip to content

Commit

Permalink
add constraint check to the constructor of modular set entry
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Mar 18, 2019
1 parent f3e5997 commit 592ac61
Showing 1 changed file with 59 additions and 78 deletions.
137 changes: 59 additions & 78 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,24 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)


// internal entry for const int bound
// This condition holds for all instances: coeff >= 0, base in [0, coeff]
struct ModularSetAnalyzer::Entry {
int64_t coeff{1};
int64_t base{0};

Entry() = default;

Entry(int64_t coeff, int64_t base) {
this->coeff = coeff;

CHECK_GE(coeff, 0);
if (coeff != 0) {
base = base % coeff;
if (base < 0) base += coeff;
}
this->base = base;
}

bool is_const() const {
return coeff == 0;
}
Expand All @@ -53,10 +67,7 @@ class ModularSetAnalyzer::Impl :
if (!override) {
CHECK(!var_map_.count(var));
}
Entry e;
e.coeff = info->coeff;
e.base = info->base;
var_map_[var] = e;
var_map_[var] = Entry(info->coeff, info->base);
}

// Detect useful constraints and use them in the analysis scope.
Expand All @@ -65,10 +76,7 @@ class ModularSetAnalyzer::Impl :
PVar<Integer> coeff, base;
// pattern match interesting constraints
if (((var % coeff) == base).Match(constraint)) {
Entry entry;
entry.coeff = coeff.Eval()->value;
entry.base = base.Eval()->value;
return UpdateByIntersect(var.Eval(), entry);
return UpdateByIntersect(var.Eval(), Entry(coeff.Eval()->value, base.Eval()->value));
}
return nullptr;
}
Expand All @@ -83,18 +91,12 @@ class ModularSetAnalyzer::Impl :
}

Entry VisitExpr_(const IntImm* op) final {
Entry ret;
ret.base = op->value;
ret.coeff = 0;
return ret;
return Entry(0, op->value);
}

Entry VisitExpr_(const UIntImm* op) final {
if (op->value < std::numeric_limits<int64_t>::max()) {
Entry ret;
ret.base = static_cast<int>(op->value);
ret.coeff = 0;
return ret;
return Entry(0, static_cast<int>(op->value));
} else {
return Everything();
}
Expand All @@ -103,19 +105,15 @@ class ModularSetAnalyzer::Impl :
Entry VisitExpr_(const Add* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base + b.base, ret.coeff);
return ret;
int64_t coeff = GCD(a.coeff, b.coeff);
return Entry(coeff, a.base + b.base);
}

Entry VisitExpr_(const Sub* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base - b.base, ret.coeff);
return ret;
int64_t coeff = GCD(a.coeff, b.coeff);
return Entry(coeff, a.base - b.base);
}

Entry VisitExpr_(const Mul* op) final {
Expand All @@ -128,10 +126,9 @@ class ModularSetAnalyzer::Impl :
int64_t pq = a.coeff * b.coeff;
int64_t pm = a.coeff * b.base;
int64_t qn = a.base * b.coeff;
Entry ret;
ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
ret.base = BaseSimplify(a.base * b.base, ret.coeff);
return ret;

int64_t coeff = GCD(pq, GCD(pm, qn));
return Entry(coeff, a.base * b.base);
}

Entry DivByConst(const Expr& lhs,
Expand All @@ -140,20 +137,15 @@ class ModularSetAnalyzer::Impl :
Entry a = VisitExpr(lhs);
CHECK_NE(val, 0);
if (a.coeff % val == 0) {
Entry ret;
if (a.base == 0) {
// a c x / c -> a x
ret.coeff = std::abs(a.coeff / val);
ret.base = 0;
return ret;
return Entry(std::abs(a.coeff / val), 0);
}
// positive division have a clear rounding mode.
// Only handle case where we clearly know we need to round down.
if (a.base > 0 && val > 0 &&
(round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
ret.coeff = a.coeff / val;
ret.base = a.base / val;
return ret;
return Entry(a.coeff / val, a.base / val);
}
}
return Everything();
Expand Down Expand Up @@ -244,22 +236,17 @@ class ModularSetAnalyzer::Impl :
*/
static Entry Union(Entry a, Entry b) {
// {ax + y} \cup {bz + h} => {gcd(a, b) x + {y or h}}
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
int64_t coeff = GCD(a.coeff, b.coeff);
if (coeff == 0) {
if (a.base == b.base) return a;
return Everything();
}
int64_t base0 = a.base % coeff;
int64_t base1 = b.base % coeff;
Entry ret;
if (base0 == base1) {
ret.coeff = coeff;
ret.base = base0;
return ret;
return Entry(coeff, base0);
} else {
ret.coeff = ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff);
ret.base = 0;
return ret;
return Entry(GCD(GCD(base0, base1), coeff), 0);
}
}
/*!
Expand All @@ -276,35 +263,20 @@ class ModularSetAnalyzer::Impl :
n = v / gcd * n;
m = v / gcd * (-m);

Entry ret;
ret.coeff = a / gcd * c;
ret.base = BaseSimplify(n * a + b, ret.coeff);
return ret;
int64_t coeff = a / gcd * c;
return Entry(coeff, n*a + b);
} else {
return Nothing();
}
}
/*!
* \brief Simplify base so that it is in [0, coeff) when coeff != 0.
* \param base The base value.
* \param coeff The coeff value.
* \return The simplified base.
*/
static int64_t BaseSimplify(int64_t base, int64_t coeff) {
if (coeff == 0) return base;
base = base % coeff;
if (base < 0) base += coeff;
return base;
}

/*!
* \brief Take GCD of a and b.
* \param a The first operand.
* \param b The second operand.
* \return The result.
*/
static int64_t ZeroAwareGCD(int64_t a, int64_t b) {
if (a < 0) a = -a;
if (b < 0) b = -b;
static int64_t GCD(int64_t a, int64_t b) {
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD (greatest common divisor)
Expand All @@ -317,42 +289,51 @@ class ModularSetAnalyzer::Impl :
}

/*!
* \brief Use Extended Euclidean algorithm to solve ax + by = 1
* \param a The first operand.
* \param b The second operand.
* \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
* \param a The first coefficient. (a >= 0)
* \param b The second coefficient. (b >= 0)
* \param x The solution of x.
* \param y The solution of y.
* \return The GCD of a and b.
*/
static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t *x, int64_t *y) {
if (b == 0) {
*x = 1;
*y = 0;
return a;
int64_t s = 0, old_s = 1;
int64_t r = b, old_r = a;

while (r != 0) {
int64_t q = old_r / r;
int64_t tmp = old_r;
old_r = r;
r = tmp - q * r;
tmp = old_s;
old_s = s;
s = tmp - q * s;
}
int64_t q = ExtendedEuclidean(b, a % b, y, x);
*y -= a / b * (*x);
return q;

*x = old_s;
if (b != 0) {
*y = (old_r - old_s * a) / b;
} else {
*y = 1;
}

return old_r;
}

/*!
* \brief return everything dtype can represent.
* \return Bound that represent everything dtype can represent.
*/
static Entry Everything() {
Entry ret;
ret.coeff = 1; ret.base = 0;
return ret;
return Entry(1, 0);
}

/*!
* \brief return an empty set
* \return An empty modular set.
*/
static Entry Nothing() {
Entry ret;
ret.coeff = 0; ret.base = 1;
return ret;
return Entry(0, 1);
}
};

Expand Down

0 comments on commit 592ac61

Please sign in to comment.