Skip to content

Commit

Permalink
[ARITH] Analyzer Infra, ConstIntBound, Modular (apache#2668)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and wweic committed Mar 12, 2019
1 parent 47a344c commit 8eb134b
Show file tree
Hide file tree
Showing 22 changed files with 1,777 additions and 389 deletions.
326 changes: 270 additions & 56 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,282 @@
#include <vector>
#include <unordered_map>
#include <memory>
#include <limits>
#include "expr.h"

namespace tvm {

// forward delcare Tensor
class Tensor;

/*! \brief namespace of arithmetic */
namespace arith {
//-------------------------------------------------------
// Base integer analysis API.
//
// We have multiple type of analyzers to do relaxed
// integer set analysis(bound analysis, modulo) and
// equivalence checking and simplification.
//
// Importantly, each analyzer may need result from
// another analyzer.
//-------------------------------------------------------

// Forward declare Analyzer
class Analyzer;
/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class ConstIntBound;
/*!
* \brief Constant integer up and lower bound(inclusive).
* Useful for value bound analysis.
*
* set = [min_value, max_value]
*/
class ConstIntBoundNode : public Node {
public:
int64_t min_value;
int64_t max_value;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("min_value", &min_value);
v->Visit("max_value", &max_value);
}

TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value);

/*! \brief Number to represent +inf */
static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
/*!
* \brief Number to represent -inf
* \note We can make use the of fact that -kPosInf == kNegInf in the project.
*/
static const constexpr int64_t kNegInf = -kPosInf;

static constexpr const char* _type_key = "arith.ConstIntBound";
TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node);
};

TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode);

/*!
* \brief Analyzer to get constant integer bound over expression.
*/
class ConstIntBoundAnalyzer {
public:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ConstIntBound operator()(const Expr& expr);

/*!
* \brief Update constant int bound information of var.
*
* \param var The variable of interest.
* \param info The bound information.
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const ConstIntBound& info,
bool override = false);
/*!
* \brief Bind variable to a range.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const Var& var, const Range& range);

private:
friend class Analyzer;
friend class ConstraintContext;
explicit ConstIntBoundAnalyzer(Analyzer* parent);
~ConstIntBoundAnalyzer();
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const Expr& constraint);
struct Entry;
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class ModularSet;
/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { coeff * x + base | x in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
class ModularSetNode : public Node {
public:
/*! \brief linear co-efficient */
int64_t coeff;
/*! \brief The base */
int64_t base;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("coeff", &coeff);
v->Visit("base", &base);
}

TVM_DLL static ModularSet make(int64_t coeff, int64_t base);

static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node);
};

TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode);

/*!
* \brief Analyzer to get modular information over expression.
*/
class ModularSetAnalyzer {
public:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ModularSet operator()(const Expr& expr);
/*!
* \brief Update constant int bound information of var.
*
* \param var The variable of interest.
* \param info The bound information.
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const ModularSet& info,
bool override = false);

private:
friend class Analyzer;
friend class ConstraintContext;
explicit ModularSetAnalyzer(Analyzer* parent);
~ModularSetAnalyzer();
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const Expr& constraint);
struct Entry;
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief A RAII constraint context.
*
* \code
*
* Var("x");
* arith::Analyzer analyzer;
* {
* arith::ConstraintContext cctx(&analyzer, x % 3 == 0);
* CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
* }
* // constraint no longer in effect.
* CHECK_NE(analyzer.modular_set(x)->coeff, 3);
*
* \endcode
*/
class ConstraintContext {
public:
/*!
* \brief Construct a constraint context.
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
*/
ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION;
/*! \brief destructor */
~ConstraintContext() DMLC_THROW_EXCEPTION {
exit_();
}

private:
/*! \brief function to be called in recovery */
std::function<void()> exit_;
};

/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overrideen.
*/
class Analyzer {
public:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief constructor */
Analyzer();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const VarExpr& var, const Range& range);
/*!
* \brief Whether can we proof expr >= val.
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can proof it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};

//-----------------------------------------------
// Integer set abstraction API.
//
// This is a API build on top of the base
// integer analysis API to provide set analysis.
//------------------------------------------------
/*!
* \brief Sign of an expression or set.
*/
Expand Down Expand Up @@ -118,42 +386,6 @@ class IntSet : public NodeRef {
static IntSet interval(Expr min, Expr max);
};

/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { coeff * x + base | x in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
struct ModularEntry {
/*! \brief linear co-efficient */
int coeff{1};
/*! \brief The base */
int base{0};

/*! \return entry represent everything */
static ModularEntry everything() {
// always safe to set 0 + x, so it can be everything.
ModularEntry e;
e.coeff = 1;
e.base = 0;
return e;
}
/*!
* \brief Add two modular entries together to get a new modular entry.
* \param a The left operand.
* \param b The right operand.
* \return The combined modular entry.
*/
static ModularEntry Add(const ModularEntry& a,
const ModularEntry& b);
};

/*!
* \brief Base class of all IntSet containers.
*/
Expand Down Expand Up @@ -300,24 +532,6 @@ IntSet DeduceBound(Expr v, Expr cond,
*/
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);

/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);

/*!
* \brief Same as EvalModular, used by front-end.
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return A ModularSet covering all possible value of e.
*/
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map);
// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
namespace tvm {
namespace ir {

using HalideIR::Internal::BaseExprNode;
using HalideIR::Internal::ExprNode;
using HalideIR::Internal::StmtNode;
using HalideIR::Internal::IRNodeType;
Expand Down
Loading

0 comments on commit 8eb134b

Please sign in to comment.