Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] Analyzer RewriteSimplifier: add/sub/mul/div/mod #2722

Merged
merged 2 commits into from
Mar 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,39 @@ class ModularSetAnalyzer {
Impl* impl_;
};

/*!
* \brief Rewrite-rule based simplifier.
*/
class RewriteSimplifier {
public:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr operator()(const Expr& expr);

/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const Expr& new_expr,
bool override = false);

private:
friend class Analyzer;
friend class ConstraintContext;
explicit RewriteSimplifier(Analyzer* parent);
~RewriteSimplifier();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief A RAII constraint context.
*
Expand Down Expand Up @@ -242,6 +275,8 @@ class Analyzer {
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplfy */
RewriteSimplifier rewrite_simplify;
/*! \brief constructor */
Analyzer();
/*!
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(self):
self._const_int_bound_update = _mod("const_int_bound_update")
self._bind = _mod("bind")
self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify")
self._enter_constraint_context = _mod("enter_constraint_context")

def const_int_bound(self, expr):
Expand Down Expand Up @@ -128,6 +129,21 @@ def modular_set(self, expr):
"""
return self._modular_set(expr)

def rewrite_simplify(self, expr):
"""Simplify expression via rewriting rules.
Parameters
----------
expr : tvm.Expr
The expression.
Returns
-------
result : Expr
The result.
"""
return self._rewrite_simplify(expr)

def bind(self, var, expr):
"""Bind a variable to the expression.
Expand Down
4 changes: 4 additions & 0 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
self->const_int_bound.Update(args[0], args[1], args[2]);
});
} else if (name == "rewrite_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->rewrite_simplify(args[0]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr();
Expand Down
11 changes: 9 additions & 2 deletions src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,30 @@
* Copyright (c) 2019 by Contributors
* \file tvm/arithmetic/analyzer.cc
*/
#include <tvm/ir.h>
#include <tvm/arithmetic.h>

namespace tvm {
namespace arith {

Analyzer::Analyzer()
: const_int_bound(this),
modular_set(this) {
modular_set(this),
rewrite_simplify(this) {
}

void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Var var(v.node_);
this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr));
this->rewrite_simplify.Update(var, this->rewrite_simplify(expr));
}

void Analyzer::Bind(const VarExpr& v, const Range& range) {
Var var(v.node_);
this->const_int_bound.Bind(var, range);
// skip modular_set
// skip rewrite simplify
}

ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) {
Expand All @@ -36,7 +40,10 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
}

bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
auto bd = this->const_int_bound(expr);
if (const auto* ptr = expr.as<ir::IntImm>()) {
return ptr->value > lower_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->min_value >= lower_bound) return true;
return false;
}
Expand Down
4 changes: 3 additions & 1 deletion src/arithmetic/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ namespace arith {
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template<typename Op>
inline Expr TryConstFold(Expr a, Expr b);
inline Expr TryConstFold(Expr a, Expr b) {
return Expr();
}

/*!
* \brief Try to run unary compute with constant folding.
Expand Down
51 changes: 46 additions & 5 deletions src/arithmetic/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

#include <tvm/ir_pass.h>
#include <tuple>
#include "const_fold.h"

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -242,20 +243,60 @@ class PBinaryExpr :
}

Expr Eval() const {
return NodeType::make(a_.Eval(), b_.Eval());
Expr lhs = a_.Eval();
Expr rhs = b_.Eval();
Expr ret = TryConstFold<NodeType>(lhs, rhs);
if (ret.defined()) return ret;
return NodeType::make(lhs, rhs);
}

private:
typename TA::Nested a_;
typename TB::Nested b_;
};

template<typename TA>
class PConstWithTypeLike :
public Pattern<PConstWithTypeLike<TA> > {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a class description and some description for its constructor parameters.

public:
PConstWithTypeLike(const TA& ref, int64_t value)
: ref_(ref), value_(value) {}

void InitMatch_() const {}

bool Match_(const NodeRef& node) const {
if (const ir::IntImm* ptr = node.as<ir::IntImm>()) {
return ptr->value == value_;
} else {
return false;
}
}

Expr Eval() const {
return make_const(ref_.Eval().type(), value_);
}

private:
typename TA::Nested ref_;
int64_t value_;
};


#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> > \
FuncName(const Pattern<TA>& a, int64_t b) { \
return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> \
FuncName(int64_t b, const Pattern<TA>& a) { \
return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
}

// arithmetic expressions
Expand Down
Loading