Skip to content

Commit

Permalink
[Arith] add SizeVar representing non-neg valued variable in a tensor …
Browse files Browse the repository at this point in the history
…shape (apache#4684)

* [arith] add ShapeVar representing non-neg valued variable in a tensor shape

* bounder remover; deal with div in int_set differently

* fix bounder_remover

* migrate unittest to use shape_var

* use tvm.shape_var in integration & relay tests

* add test case; fix Var register

* fix lint

* fix lint again

* add default ShapeVar visitor in Relay

* fix override

* fix ShapeVar visit bug

* revert IntervalSet for shape_var

* remove bound_remover

* remove is_var; use constructor for shapevar/var instead

* ShapeVar -> SizeVar; add constructor comments

* shape_var -> size_var in doc

* tindex -> size
  • Loading branch information
yzhliu authored and alexwong committed Feb 26, 2020
1 parent d60d616 commit 4de6761
Show file tree
Hide file tree
Showing 62 changed files with 417 additions and 267 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/tvm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The user facing API for computation declaration.
tvm.load_json
tvm.save_json
tvm.var
tvm.size_var
tvm.const
tvm.convert
tvm.placeholder
Expand All @@ -49,6 +50,7 @@ The user facing API for computation declaration.
.. autofunction:: tvm.load_json
.. autofunction:: tvm.save_json
.. autofunction:: tvm.var
.. autofunction:: tvm.size_var
.. autofunction:: tvm.const
.. autofunction:: tvm.convert
.. autofunction:: tvm.placeholder
Expand Down
59 changes: 56 additions & 3 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,33 @@ class Var;
*/
class VarNode : public PrimExprNode {
public:
/*! \brief constructor */
VarNode() {}
VarNode(DataType dtype, std::string name_hint);

/*!
* \brief The hint to the variable name.
* \note Each variable is uniquely identified by its address.
*/
std::string name_hint;

static Var make(DataType dtype, std::string name_hint);

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("name", &name_hint);
}

static constexpr const char* _type_key = "Variable";
TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, PrimExprNode);
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};

/*! \brief a named variable in TVM */
class Var : public PrimExpr {
public:
explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
/*! \brief constructor
* \param name_hint variable name
* \param t data type
*/
TVM_DLL explicit Var(std::string name_hint = "v",
DataType t = DataType::Int(32));
/*!
Expand Down Expand Up @@ -114,6 +120,53 @@ class Var : public PrimExpr {
using ContainerType = VarNode;
};

class SizeVar;
/*!
* \brief A variable node represent a tensor index size,
* whose value must be non-negative.
*/
class SizeVarNode : public VarNode {
public:
/*! \brief constructor */
SizeVarNode() {}
/*! \brief constructor
* \param dtype data type
* \param name_hint variable name
*/
SizeVarNode(DataType dtype, std::string name_hint);

static constexpr const char* _type_key = "SizeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
};

/*! \brief a named variable represents a tensor index size */
class SizeVar : public Var {
public:
explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
/*! \brief constructor
* \param name_hint variable name
* \param t data type
*/
TVM_DLL explicit SizeVar(std::string name_hint = "s",
DataType t = DataType::Int(32));
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const SizeVarNode* operator->() const {
return get();
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const SizeVarNode* get() const {
return static_cast<const SizeVarNode*>(data_.get());
}
/*! \brief type indicate the container type */
using ContainerType = SizeVarNode;
};

/*!
* \brief Container of constant int that adds more constructors.
*
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace ir {
using IntImmNode = tvm::IntImmNode;
using FloatImmNode = tvm::FloatImmNode;
using VarNode = tvm::VarNode;
using SizeVarNode = tvm::SizeVarNode;

/*! \brief String constants, only used in asserts. */
class StringImmNode : public PrimExprNode {
Expand Down Expand Up @@ -679,7 +680,7 @@ class AnyNode : public PrimExprNode {
void VisitAttrs(AttrVisitor* v) {}
/*! \brief Convert to var. */
Var ToVar() const {
return VarNode::make(DataType::Int(32), "any_dim");
return Var("any_dim", DataType::Int(32));
}

TVM_DLL static PrimExpr make();
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
}
// Functions that can be overriden by subclass
virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const SizeVarNode* op, Args... args) {
return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
}
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -174,6 +177,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
FType vtable;
// Set dispatch
IR_EXPR_FUNCTOR_DISPATCH(VarNode);
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
Expand Down Expand Up @@ -297,6 +301,7 @@ class TVM_DLL ExprVisitor :
using ExprFunctor::VisitExpr;
// list of functions to override.
void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const SizeVarNode* op) override;
void VisitExpr_(const LoadNode* op) override;
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const CallNode* op) override;
Expand Down Expand Up @@ -341,6 +346,7 @@ class TVM_DLL ExprMutator :
using ExprFunctor::VisitExpr;
// list of functions to override.
PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const SizeVarNode* op) override;
PrimExpr VisitExpr_(const LoadNode* op) override;
PrimExpr VisitExpr_(const LetNode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override;
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,25 @@ def var(name="tindex", dtype=int32):
return _api_internal._Var(name, dtype)


def size_var(name="size", dtype=int32):
"""Create a new variable represents a tensor shape size, which is non-negative.
Parameters
----------
name : str
The name
dtype : str
The data type
Returns
-------
var : SizeVar
The result symbolic shape variable.
"""
return _api_internal._SizeVar(name, dtype)


def any(*args):
"""Create a new experssion of the union of all conditions in the arguments
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,25 @@ def __init__(self, name, dtype):
_api_internal._Var, name, dtype)


@register_object
class SizeVar(Var):
"""Symbolic variable to represent a tensor index size
which is greater or equal to zero
Parameters
----------
name : str
The name
dtype : int
The data type
"""
# pylint: disable=super-init-not-called
def __init__(self, name, dtype):
self.__init_handle_by_constructor__(
_api_internal._SizeVar, name, dtype)


@register_object
class Reduce(PrimExpr):
"""Reduce node.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/hybrid/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def visit_Call(self, node):
_internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
['range', 'max', 'min', 'len'] + \
list(self.symbols.keys()), \
"Function call id not in intrinsics' list")
"Function call id " + func_id + " not in intrinsics' list")
for elem in node.args:
self.visit(elem)

Expand Down
7 changes: 6 additions & 1 deletion src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ namespace ir {

TVM_REGISTER_GLOBAL("_Var")
.set_body_typed([](std::string s, DataType t) {
return VarNode::make(t, s);
return Var(s, t);
});

TVM_REGISTER_GLOBAL("_SizeVar")
.set_body_typed([](std::string s, DataType t) {
return SizeVar(s, t);
});

TVM_REGISTER_GLOBAL("make.abs")
Expand Down
3 changes: 2 additions & 1 deletion src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class BoundDeducer: public ExprVisitor {

void VisitExpr(const PrimExpr& e) final {
if (!success_) return;
if (e.get() == path_[iter_++]) {
if (iter_ < path_.size() && e.get() == path_[iter_++]) {
ExprVisitor::VisitExpr(e);
} else {
success_ = false;
Expand Down Expand Up @@ -297,6 +297,7 @@ void BoundDeducer::Transform() {
void BoundDeducer::Deduce() {
Init();
if (!success_) return;

Relax();
if (!success_) return;
// get the path
Expand Down
10 changes: 10 additions & 0 deletions src/arithmetic/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,16 @@ class ConstIntBoundAnalyzer::Impl :
}
}

Entry VisitExpr_(const SizeVarNode* op) final {
SizeVar v = GetRef<SizeVar>(op);
auto it = var_map_.find(v);
if (it != var_map_.end()) {
return it->second;
} else {
return MakeBound(0, kPosInf);
}
}

Entry VisitRightShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
Expand Down
1 change: 1 addition & 0 deletions src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ class IntervalSetEvaluator :
}
}


IntervalSet VisitExpr_(const AddNode* op) final {
return VisitBinaryExpr_(op);
}
Expand Down
4 changes: 4 additions & 0 deletions src/ir/attr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
// deep comparison of symbolic integer expressions.
virtual R VisitAttr_(const VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const SizeVarNode* op, Args... args) {
return VisitAttr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
}
virtual R VisitAttr_(const ir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -115,6 +118,7 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(FloatImmNode);
ATTR_FUNCTOR_DISPATCH(StringImmNode);
ATTR_FUNCTOR_DISPATCH(VarNode);
ATTR_FUNCTOR_DISPATCH(SizeVarNode);
ATTR_FUNCTOR_DISPATCH(AddNode);
ATTR_FUNCTOR_DISPATCH(SubNode);
ATTR_FUNCTOR_DISPATCH(MulNode);
Expand Down
2 changes: 1 addition & 1 deletion src/ir/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& ot
bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<ArrayNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!Equal(lhs->data[i], rhs->data[i])) return false;
}
}
Expand Down
16 changes: 10 additions & 6 deletions src/lang/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,19 @@ PrimExpr::PrimExpr(std::string str)
: PrimExpr(ir::StringImmNode::make(str)) {}

Var::Var(std::string name_hint, DataType t)
: Var(VarNode::make(t, name_hint)) {}
: Var(make_object<VarNode>(t, name_hint)) {}

Var VarNode::make(DataType t, std::string name_hint) {
ObjectPtr<VarNode> node = make_object<VarNode>();
node->dtype = t;
node->name_hint = std::move(name_hint);
return Var(node);
VarNode::VarNode(DataType t, std::string name_hint) {
this->dtype = t;
this->name_hint = std::move(name_hint);
}

SizeVar::SizeVar(std::string name_hint, DataType t)
: SizeVar(make_object<SizeVarNode>(t, name_hint)) {}

SizeVarNode::SizeVarNode(DataType t, std::string name_hint)
: VarNode(t, std::move(name_hint)) {}

Range::Range(PrimExpr begin, PrimExpr end)
: Range(make_object<RangeNode>(
begin,
Expand Down
5 changes: 5 additions & 0 deletions src/lang/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,10 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
// stream << op->name << "." << op->type;
p->stream << op->name_hint;
})
.set_dispatch<SizeVarNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const SizeVarNode*>(node.get());
p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}";
})
.set_dispatch<AddNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const AddNode*>(node.get());
p->stream << '(';
Expand Down Expand Up @@ -1143,6 +1147,7 @@ TVM_REGISTER_NODE_TYPE(IntImmNode);
TVM_REGISTER_NODE_TYPE(StringImmNode);
TVM_REGISTER_NODE_TYPE(CastNode);
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_NODE_TYPE(SizeVarNode);
TVM_REGISTER_NODE_TYPE(AddNode);
TVM_REGISTER_NODE_TYPE(SubNode);
TVM_REGISTER_NODE_TYPE(MulNode);
Expand Down
8 changes: 8 additions & 0 deletions src/pass/ir_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ void StmtVisitor::VisitStmt_(const EvaluateNode* op) {

void ExprVisitor::VisitExpr_(const VarNode* op) {}

void ExprVisitor::VisitExpr_(const SizeVarNode* op) {
this->VisitExpr_(static_cast<const VarNode*>(op));
}

void ExprVisitor::VisitExpr_(const LoadNode* op) {
this->VisitExpr(op->index);
this->VisitExpr(op->predicate);
Expand Down Expand Up @@ -596,6 +600,10 @@ PrimExpr ExprMutator::VisitExpr_(const VarNode* op) {
return GetRef<PrimExpr>(op);
}

PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) {
return this->VisitExpr_(static_cast<const VarNode*>(op));
}

PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
PrimExpr index = this->VisitExpr(op->index);
PrimExpr predicate = this->VisitExpr(op->predicate);
Expand Down
Loading

0 comments on commit 4de6761

Please sign in to comment.