Skip to content

Commit

Permalink
[ATTR] Introduce Integer container (apache#1994)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Oct 25, 2018
1 parent b20a6d4 commit a71b34d
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 21 deletions.
2 changes: 1 addition & 1 deletion include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <string>
#include "ir.h"
#include "base.h"
#include "expr.h"
#include "packed_func_ext.h"

namespace tvm {
Expand Down Expand Up @@ -73,7 +74,6 @@ inline Type NullValue<Type>() {
return Type(Type::Handle, 0, 0);
}


/*! \brief Error thrown during attribute checking. */
struct AttrError : public dmlc::Error {
/*!
Expand Down
46 changes: 46 additions & 0 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ using HalideIR::VarExpr;
using HalideIR::IR::RangeNode;
using HalideIR::IR::FunctionRef;
using HalideIR::IR::FunctionBaseNode;
using HalideIR::Internal::IntImm;
using HalideIR::Internal::Stmt;
using HalideIR::Internal::IRPrinter;
using HalideIR::Internal::Variable;
Expand Down Expand Up @@ -83,6 +84,51 @@ class Var : public HalideIR::VarExpr {
};


/*!
* \brief Container of constant ineteger (IntImm).
*
* This is used to store and automate type check
* attributes that must be constant integer.
*/
class Integer : public Expr {
public:
Integer() : Expr() {}
/*!
* \brief constructor from node.
*/
explicit Integer(NodePtr<Node> node) : Expr(node) {}
/*!
* \brief Construct integer from int value.
*/
Integer(int value) : Expr(value) {} // NOLINT(*)
/*!
* \brief Assign an expression to integer.
* \param other another expression.
*/
Integer& operator=(const Integer& other) {
node_ = other.node_;
return *this;
}
/*!
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
const IntImm* operator->() const {
return static_cast<const IntImm*>(node_.get());
}
/*!
* \brief convert to int64_t
*/
operator int64_t() const {
CHECK(node_ != nullptr)
<< " Trying get reference a null Integer";
return (*this)->value;
}
/*! \brief type indicate the container type */
using ContainerType = IntImm;
};


/*! \brief container class of iteration variable. */
class IterVarNode;

Expand Down
17 changes: 17 additions & 0 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <sstream>
#include <string>
#include <memory>
#include <limits>
#include <type_traits>

#include "base.h"
Expand Down Expand Up @@ -126,6 +127,8 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
inline TVMArgValue::operator HalideIR::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Expr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kDLFloat) {
Expand All @@ -145,6 +148,20 @@ inline TVMArgValue::operator HalideIR::Expr() const {
return Expr(sptr);
}

inline TVMArgValue::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64));
}
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<Integer>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<Expr>()
<< " but get " << sptr->type_key();
return Integer(sptr);
}

inline NodePtr<Node>& TVMArgValue::node_sptr() {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return *ptr<NodePtr<Node> >();
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
IndexExpr size;
IndexExpr axis;
int axis;
double bias;
double alpha;
double beta;
Expand All @@ -340,7 +340,7 @@ struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
/*! \brief Attributes for L2Normalize operator */
struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> {
double eps;
Array<IndexExpr> axis;
Array<Integer> axis;

TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") {
TVM_ATTR_FIELD(eps)
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {

/*! \brief Attributes used in transpose operators */
struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
Array<IndexExpr> axes;
Array<Integer> axes;
TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
TVM_ATTR_FIELD(axes)
.describe("The target axes order, reverse order if not specified.");
Expand All @@ -70,10 +70,10 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
}; // struct ReshapeAttrs

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
IndexExpr axis;
Integer axis;

TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<IndexExpr>())
TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
.describe("The axis over which to select values.");
}
};
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ struct Expr;
#endif

namespace tvm {
// forward declarations
class Integer;

namespace runtime {
// forward declarations
class TVMArgs;
Expand Down Expand Up @@ -559,6 +562,7 @@ class TVMArgValue : public TVMPODValue_ {
inline bool IsNodeType() const;
inline operator HalideIR::Type() const;
inline operator HalideIR::Expr() const;
inline operator tvm::Integer() const;
// get internal node ptr, if it is node
inline NodePtr<Node>& node_sptr();
};
Expand Down
6 changes: 3 additions & 3 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ TVM_REGISTER_NODE_TYPE(LRNAttrs);

Expr MakeLRN(Expr data,
IndexExpr size,
IndexExpr axis,
int axis,
double alpha,
double beta,
double bias) {
Expand All @@ -337,7 +337,7 @@ TVM_REGISTER_API("relay.op.nn._make.lrn")
});

RELAY_REGISTER_OP("nn.lrn")
.describe(R"code(LRN layer.
.describe(R"code(LRN layer.
Normalize the input in a local region across or within feature maps.
Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta,
Expand All @@ -362,7 +362,7 @@ TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs);

Expr MakeL2Normalize(Expr data,
double eps,
Array<IndexExpr> axis) {
Array<Integer> axis) {
auto attrs = make_node<L2NormalizeAttrs>();
attrs->eps = eps;
attrs->axis = std::move(axis);
Expand Down
23 changes: 11 additions & 12 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,24 +218,23 @@ bool TransposeRel(const Array<Type>& types,
}
const auto* param = attrs.as<TransposeAttrs>();
const int ndim = data->shape.size();
const Array<IndexExpr>& axes = param->axes;
const Array<Integer>& axes = param->axes;
// check dimension match
CHECK(axes.empty() || static_cast<int>(axes.size()) == ndim)
CHECK(!axes.defined() || static_cast<int>(axes.size()) == ndim)
<< "Dimension mismatch: axes has " << axes.size() << " elements"
<< ", but data.ndim = " << ndim;
// construct int_axes
std::vector<int> int_axes;
int_axes.reserve(ndim);
if (axes.empty()) {
// used not defined to check if it is None.
if (!axes.defined()) {
for (int i = ndim - 1; i >= 0; --i) {
int_axes.push_back(i);
}
} else {
std::vector<int> axis_used(ndim, 0);
for (const IndexExpr& e : axes) {
const int64_t *axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr);
int axis = *axis_ptr;
for (const Integer& e : axes) {
int64_t axis = e;
// sanity check for axis and ndim
CHECK(-ndim <= axis && axis < ndim)
<< "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)"
Expand All @@ -245,7 +244,7 @@ bool TransposeRel(const Array<Type>& types,
// sanity check for duplication
CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis;
axis_used[axis] = 1;
int_axes.push_back(axis);
int_axes.push_back(static_cast<int>(axis));
}
}
std::vector<IndexExpr> oshape;
Expand All @@ -258,7 +257,7 @@ bool TransposeRel(const Array<Type>& types,
}

Expr MakeTranspose(Expr data,
Array<IndexExpr> axes) {
Array<Integer> axes) {
auto attrs = make_node<TransposeAttrs>();
attrs->axes = std::move(axes);
static const Op& op = Op::Get("transpose");
Expand Down Expand Up @@ -401,7 +400,7 @@ bool TakeRel(const Array<Type>& types,
std::vector<IndexExpr> oshape;
const auto ndim_data = static_cast<int>(data->shape.size());
const auto ndim_indices = static_cast<int>(indices->shape.size());
auto axis = (*as_const_int(param->axis));
int axis = static_cast<int>(param->axis->value);
if (axis < 0) axis += ndim_data;
CHECK_LE(axis, ndim_data)
<< "axis should be with in data shape"
Expand All @@ -424,9 +423,9 @@ bool TakeRel(const Array<Type>& types,

Expr MakeTake(Expr data,
Expr indices,
IndexExpr axis) {
Integer axis) {
auto attrs = make_node<TakeAttrs>();
attrs->axis = axis;
attrs->axis = std::move(axis);
static const Op& op = Op::Get("take");
return CallNode::make(op, {data, indices}, Attrs(attrs), {});
}
Expand Down

0 comments on commit a71b34d

Please sign in to comment.