diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 33d84cecec6a..51d916ca488d 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -35,6 +35,7 @@ #include #include "ir.h" #include "base.h" +#include "expr.h" #include "packed_func_ext.h" namespace tvm { @@ -73,7 +74,6 @@ inline Type NullValue() { return Type(Type::Handle, 0, 0); } - /*! \brief Error thrown during attribute checking. */ struct AttrError : public dmlc::Error { /*! diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 7fdca7f6af8e..37b122ae5b03 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -29,6 +29,7 @@ using HalideIR::VarExpr; using HalideIR::IR::RangeNode; using HalideIR::IR::FunctionRef; using HalideIR::IR::FunctionBaseNode; +using HalideIR::Internal::IntImm; using HalideIR::Internal::Stmt; using HalideIR::Internal::IRPrinter; using HalideIR::Internal::Variable; @@ -83,6 +84,51 @@ class Var : public HalideIR::VarExpr { }; +/*! + * \brief Container of constant ineteger (IntImm). + * + * This is used to store and automate type check + * attributes that must be constant integer. + */ +class Integer : public Expr { + public: + Integer() : Expr() {} + /*! + * \brief constructor from node. + */ + explicit Integer(NodePtr node) : Expr(node) {} + /*! + * \brief Construct integer from int value. + */ + Integer(int value) : Expr(value) {} // NOLINT(*) + /*! + * \brief Assign an expression to integer. + * \param other another expression. + */ + Integer& operator=(const Integer& other) { + node_ = other.node_; + return *this; + } + /*! + * \brief Get pointer to the internal value. + * \return the content of the integer. + */ + const IntImm* operator->() const { + return static_cast(node_.get()); + } + /*! + * \brief convert to int64_t + */ + operator int64_t() const { + CHECK(node_ != nullptr) + << " Trying get reference a null Integer"; + return (*this)->value; + } + /*! \brief type indicate the container type */ + using ContainerType = IntImm; +}; + + /*! \brief container class of iteration variable. */ class IterVarNode; diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 0491f3057815..c5a83608c617 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "base.h" @@ -126,6 +127,8 @@ inline TNodeRef TVMArgValue::AsNodeRef() const { inline TVMArgValue::operator HalideIR::Expr() const { if (type_code_ == kNull) return Expr(); if (type_code_ == kDLInt) { + CHECK_LE(value_.v_int64, std::numeric_limits::max()); + CHECK_GE(value_.v_int64, std::numeric_limits::min()); return Expr(static_cast(value_.v_int64)); } if (type_code_ == kDLFloat) { @@ -145,6 +148,20 @@ inline TVMArgValue::operator HalideIR::Expr() const { return Expr(sptr); } +inline TVMArgValue::operator tvm::Integer() const { + if (type_code_ == kNull) return Integer(); + if (type_code_ == kDLInt) { + CHECK_LE(value_.v_int64, std::numeric_limits::max()); + CHECK_GE(value_.v_int64, std::numeric_limits::min()); + return Integer(static_cast(value_.v_int64)); + } + NodePtr& sptr = *ptr >(); + CHECK(NodeTypeChecker::Check(sptr.get())) + << "Expected type " << NodeTypeName() + << " but get " << sptr->type_key(); + return Integer(sptr); +} + inline NodePtr& TVMArgValue::node_sptr() { TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); return *ptr >(); diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index eb044ccb29fd..34bd5eb93312 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -317,7 +317,7 @@ struct BatchNormAttrs : public tvm::AttrsNode { /*! \brief Attributes for LRN operator */ struct LRNAttrs : public tvm::AttrsNode { IndexExpr size; - IndexExpr axis; + int axis; double bias; double alpha; double beta; @@ -340,7 +340,7 @@ struct LRNAttrs : public tvm::AttrsNode { /*! \brief Attributes for L2Normalize operator */ struct L2NormalizeAttrs : public tvm::AttrsNode { double eps; - Array axis; + Array axis; TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") { TVM_ATTR_FIELD(eps) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 1941e045ed8d..b0150c4ac3d9 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -53,7 +53,7 @@ struct ConcatenateAttrs : public tvm::AttrsNode { /*! \brief Attributes used in transpose operators */ struct TransposeAttrs : public tvm::AttrsNode { - Array axes; + Array axes; TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") { TVM_ATTR_FIELD(axes) .describe("The target axes order, reverse order if not specified."); @@ -70,10 +70,10 @@ struct ReshapeAttrs : public tvm::AttrsNode { }; // struct ReshapeAttrs struct TakeAttrs : public tvm::AttrsNode { - IndexExpr axis; + Integer axis; TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue()) + TVM_ATTR_FIELD(axis).set_default(NullValue()) .describe("The axis over which to select values."); } }; diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index a8fa096e51c4..c306f8d15160 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -32,6 +32,9 @@ struct Expr; #endif namespace tvm { +// forward declarations +class Integer; + namespace runtime { // forward declarations class TVMArgs; @@ -559,6 +562,7 @@ class TVMArgValue : public TVMPODValue_ { inline bool IsNodeType() const; inline operator HalideIR::Type() const; inline operator HalideIR::Expr() const; + inline operator tvm::Integer() const; // get internal node ptr, if it is node inline NodePtr& node_sptr(); }; diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 8459a99cde23..d38c5a0ebe0d 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -317,7 +317,7 @@ TVM_REGISTER_NODE_TYPE(LRNAttrs); Expr MakeLRN(Expr data, IndexExpr size, - IndexExpr axis, + int axis, double alpha, double beta, double bias) { @@ -337,7 +337,7 @@ TVM_REGISTER_API("relay.op.nn._make.lrn") }); RELAY_REGISTER_OP("nn.lrn") - .describe(R"code(LRN layer. +.describe(R"code(LRN layer. Normalize the input in a local region across or within feature maps. Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta, @@ -362,7 +362,7 @@ TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs); Expr MakeL2Normalize(Expr data, double eps, - Array axis) { + Array axis) { auto attrs = make_node(); attrs->eps = eps; attrs->axis = std::move(axis); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index bab875fd190e..29dff1e4ba27 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -218,24 +218,23 @@ bool TransposeRel(const Array& types, } const auto* param = attrs.as(); const int ndim = data->shape.size(); - const Array& axes = param->axes; + const Array& axes = param->axes; // check dimension match - CHECK(axes.empty() || static_cast(axes.size()) == ndim) + CHECK(!axes.defined() || static_cast(axes.size()) == ndim) << "Dimension mismatch: axes has " << axes.size() << " elements" << ", but data.ndim = " << ndim; // construct int_axes std::vector int_axes; int_axes.reserve(ndim); - if (axes.empty()) { + // used not defined to check if it is None. + if (!axes.defined()) { for (int i = ndim - 1; i >= 0; --i) { int_axes.push_back(i); } } else { std::vector axis_used(ndim, 0); - for (const IndexExpr& e : axes) { - const int64_t *axis_ptr = as_const_int(e); - CHECK(axis_ptr != nullptr); - int axis = *axis_ptr; + for (const Integer& e : axes) { + int64_t axis = e; // sanity check for axis and ndim CHECK(-ndim <= axis && axis < ndim) << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" @@ -245,7 +244,7 @@ bool TransposeRel(const Array& types, // sanity check for duplication CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis; axis_used[axis] = 1; - int_axes.push_back(axis); + int_axes.push_back(static_cast(axis)); } } std::vector oshape; @@ -258,7 +257,7 @@ bool TransposeRel(const Array& types, } Expr MakeTranspose(Expr data, - Array axes) { + Array axes) { auto attrs = make_node(); attrs->axes = std::move(axes); static const Op& op = Op::Get("transpose"); @@ -401,7 +400,7 @@ bool TakeRel(const Array& types, std::vector oshape; const auto ndim_data = static_cast(data->shape.size()); const auto ndim_indices = static_cast(indices->shape.size()); - auto axis = (*as_const_int(param->axis)); + int axis = static_cast(param->axis->value); if (axis < 0) axis += ndim_data; CHECK_LE(axis, ndim_data) << "axis should be with in data shape" @@ -424,9 +423,9 @@ bool TakeRel(const Array& types, Expr MakeTake(Expr data, Expr indices, - IndexExpr axis) { + Integer axis) { auto attrs = make_node(); - attrs->axis = axis; + attrs->axis = std::move(axis); static const Op& op = Op::Get("take"); return CallNode::make(op, {data, indices}, Attrs(attrs), {}); }