Skip to content

Commit

Permalink
[REFACTOR][IR] Unify IntImm and UIntImm (apache#4706)
Browse files Browse the repository at this point in the history
* [REFACTOR][IR] Unify IntImm and UIntImm

This PR unifies UIntImm and IntImm to simplify the codebase.
Unsigned integer constants will also be stored as IntImm.

For uint constant that does not fit into int64(rare case), we introduced
an intrinsic tvm_big_uint_imm to construct such intgers by its
lower and higher 32bits.

* [REFACTOR][IR] Remove UIntImm to use IntImm

* rename big->large
  • Loading branch information
tqchen authored and alexwong committed Feb 26, 2020
1 parent 9da8a47 commit 8b6aeaa
Show file tree
Hide file tree
Showing 74 changed files with 361 additions and 413 deletions.
4 changes: 0 additions & 4 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,6 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
}
Expand Down Expand Up @@ -523,8 +521,6 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
*ptr = static_cast<double>(op->value);
} else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
}
Expand Down
48 changes: 14 additions & 34 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,56 +115,38 @@ class Var : public PrimExpr {
using ContainerType = VarNode;
};

class Integer;
/*! \brief ExprNode: constant integer. */
class IntImmNode : public PrimExprNode {
public:
/*! \brief the Internal value. */
int64_t value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}

TVM_DLL static Integer make(DataType t, int64_t value);

static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};

/*!
* \brief Container of constant integer (IntImm).
* \brief Container of constant int that adds more constructors.
*
* This is used to store and automate type check
* attributes that must be constant integer.
*
* \sa IntImm
*/
class Integer : public PrimExpr {
class Integer : public IntImm {
public:
Integer() : PrimExpr() {}
Integer() {}
/*!
* \brief constructor from node.
*/
explicit Integer(ObjectPtr<Object> node) : PrimExpr(node) {}
explicit Integer(ObjectPtr<Object> node) : IntImm(node) {}
/*!
* \brief Construct integer from int value.
*/
Integer(int value) : PrimExpr(value) {} // NOLINT(*)
Integer(int value) : IntImm(DataType::Int(32), value) {} // NOLINT(*)
/*!
* \brief Construct integer from int imm.
* \param other The other value.
*/
Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*)
/*!
* \brief Assign an expression to integer.
* \param other another expression.
*/
Integer& operator=(const Integer& other) {
data_ = other.data_;
Integer& operator=(const IntImm& other) {
data_ = ObjectRef::GetDataPtr<Object>(other);
return *this;
}
/*!
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
const IntImmNode* operator->() const {
return static_cast<const IntImmNode*>(get());
}
/*!
* \brief convert to int64_t
*/
Expand All @@ -173,8 +155,6 @@ class Integer : public PrimExpr {
<< " Trying to reference a null Integer";
return (*this)->value;
}
/*! \brief type indicate the container type */
using ContainerType = IntImmNode;
};

/*! \brief range over one dimension */
Expand Down
53 changes: 27 additions & 26 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include <algorithm>
#include <type_traits>
#include <limits>
#include "expr.h"
#include "ir.h"

Expand Down Expand Up @@ -82,21 +83,6 @@ inline const int64_t* as_const_int(const PrimExpr& x) {
}
}

/*!
* \brief Get x as constant uint expression.
* \param x The expression
* \return the address to the int expression,
* return nullptr, if x is not UIntImm.
*/
inline const uint64_t* as_const_uint(const PrimExpr& x) {
if (!x.defined()) return nullptr;
if (const ir::UIntImmNode* op = x.as<ir::UIntImmNode>()) {
return &(op->value);
} else {
return nullptr;
}
}

/*!
* \brief Check whether x is a constant integer expression.
* \param x The input argument
Expand Down Expand Up @@ -597,6 +583,15 @@ TVM_DLL PrimExpr nearbyint(PrimExpr x);
*/
TVM_DLL PrimExpr trunc(PrimExpr x);

/*!
* \brief Construct a large uint constant by its low 32 bits and high 32bits.
* \param dtype The final data type.
* \param low The lower 32 bits.
* \param high The higher 32 bits.
* \return The constructed expression.
*/
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
Expand All @@ -617,11 +612,11 @@ TVM_DECLARE_INTRIN_UNARY(atan);

// Implementation details after this
inline bool is_const(const PrimExpr& x) {
if (x.as<ir::IntImmNode>() || x.as<ir::UIntImmNode>()) {
if (x.as<ir::IntImmNode>()) {
return true;
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
const PrimExpr& val = op->value;
if (val.as<ir::IntImmNode>() || val.as<ir::UIntImmNode>()) {
if (val.as<ir::IntImmNode>()) {
return true;
}
}
Expand All @@ -631,8 +626,6 @@ inline bool is_const(const PrimExpr& x) {
inline bool is_positive_const(const PrimExpr& a) {
if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value > 0;
} else if (const ir::UIntImmNode* op = a.as<ir::UIntImmNode>()) {
return op->value > 0;
} else {
return false;
}
Expand All @@ -649,14 +642,10 @@ inline bool is_negative_const(const PrimExpr& a) {
inline bool is_const_int(const PrimExpr& x, int64_t value) {
if (const auto* op = x.as<ir::IntImmNode>()) {
return op->value == value;
} else if (const auto* op = x.as<ir::UIntImmNode>()) {
return op->value == static_cast<uint64_t>(value);
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
const PrimExpr& val = op->value;
if (const auto* opv = val.as<ir::IntImmNode>()) {
return opv->value == value;
} else if (const auto* opv = val.as<ir::UIntImmNode>()) {
return opv->value == static_cast<uint64_t>(value);
}
}
return false;
Expand All @@ -675,15 +664,27 @@ inline bool is_no_op(const Stmt& stmt) {

template<typename ValueType>
inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
if (t.is_int()) return ir::IntImmNode::make(t, static_cast<int64_t>(value));
if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast<uint64_t>(value));
if (t.is_int()) return IntImm(t, static_cast<int64_t>(value));
if (t.is_uint()) {
// Use IntImm if it is a small integer
uint64_t uval = static_cast<uint64_t>(value);
if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
return IntImm(t, static_cast<int64_t>(value));
} else {
uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
uint64_t low = uval & mask;
uint64_t high = uval >> 32U;
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
}
}
if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin))
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) {
return ir::FloatImmNode::make(t, static_cast<double>(value));
}
LOG(FATAL) << "cannot make const for type " << t;
return PrimExpr();
}
Expand Down
27 changes: 10 additions & 17 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,6 @@ namespace ir {
using IntImmNode = tvm::IntImmNode;
using VarNode = tvm::VarNode;

/*! \brief constant unsigned integer. */
class UIntImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
uint64_t value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}

TVM_DLL static PrimExpr make(DataType t, uint64_t value);

static constexpr const char* _type_key = "UIntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode);
};

/*! \brief Floating point constants. */
class FloatImmNode : public PrimExprNode {
public:
Expand Down Expand Up @@ -1422,6 +1405,16 @@ inline bool IsPragmaKey(const std::string& attr_key) {

/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
/*!
* \brief See pesudo code
*
* Construct a big uint that may not be representable by int64
*
* Expr tvm_large_uint_imm(uint32_t v0, uin32_t v1) {
* return (v1 << 32) | v0;
* }
*/
constexpr const char* tvm_large_uint_imm = "tvm_large_uint_imm";
/*!
* \brief See pesudo code
*
Expand Down
50 changes: 50 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,56 @@ class PrimExpr : public BaseExpr {
using ContainerType = PrimExprNode;
};

/*!
* \brief Constant integer literals in the program.
* \sa IntImm
*/
class IntImmNode : public PrimExprNode {
public:
/*! \brief the Internal value. */
int64_t value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}

static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};

/*!
* \brief Managed reference class to IntImmNode.
*
* \sa IntImmNode
*/
class IntImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
IntImm() {}
/*!
* \brief constructor from node.
*/
explicit IntImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL IntImm(DataType dtype, int64_t value);
/*!
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
const IntImmNode* operator->() const {
return static_cast<const IntImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = IntImmNode;
};

/*!
* \brief Base node of all non-primitive expressions.
*
Expand Down
4 changes: 0 additions & 4 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Object* op, Args ...) {
Expand Down Expand Up @@ -203,7 +202,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode);
IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode);
IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode);
IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
return vtable;
Expand Down Expand Up @@ -327,7 +325,6 @@ class TVM_DLL ExprVisitor :
void VisitExpr_(const BroadcastNode* op) override;
void VisitExpr_(const ShuffleNode* op) override;
void VisitExpr_(const IntImmNode* op) override;
void VisitExpr_(const UIntImmNode* op) override;
void VisitExpr_(const FloatImmNode* op) override;
void VisitExpr_(const StringImmNode* op) override;
};
Expand Down Expand Up @@ -372,7 +369,6 @@ class TVM_DLL ExprMutator :
PrimExpr VisitExpr_(const BroadcastNode* op) override;
PrimExpr VisitExpr_(const ShuffleNode* op) override;
PrimExpr VisitExpr_(const IntImmNode* op) override;
PrimExpr VisitExpr_(const UIntImmNode* op) override;
PrimExpr VisitExpr_(const FloatImmNode* op) override;
PrimExpr VisitExpr_(const StringImmNode* op) override;
};
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def const(value, dtype=None):
"""
if dtype is None:
dtype = _scalar_type_inference(value)
if dtype == "uint64" and value >= (1 << 63):
return _api_internal._LargeUIntImm(
dtype, value & ((1 << 32) - 1), value >> 32)
return _api_internal._const(value, dtype)


Expand Down
4 changes: 2 additions & 2 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def args_to_workload(x, topi_compute_func=None):
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
workload = x
elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
workload = x.value
elif x is None:
workload = 0
Expand Down Expand Up @@ -344,7 +344,7 @@ def _count_flop(exp):
if len(source) != 1:
raise FlopCalculationError("Found multiple output in the source of reduce op")
return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
if isinstance(exp, (expr.FloatImm, expr.IntImm)):
return 0
if isinstance(exp, expr.Cast):
return _count_flop(exp.value)
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/autotvm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ def get_const_int(exp):
"""
if isinstance(exp, int):
return exp
if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
if not isinstance(exp, (expr.IntImm,)):
exp = ir_pass.Simplify(exp)
if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
if not isinstance(exp, (expr.IntImm,)):
raise ValueError("Expect value to be constant int")
return exp.value

Expand All @@ -179,9 +179,9 @@ def get_const_tuple(in_tuple):
for elem in in_tuple:
if isinstance(elem, expr.Var):
ret.append(elem)
elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)):
elif not isinstance(elem, (expr.IntImm, int)):
elem = ir_pass.Simplify(elem)
if not isinstance(elem, (expr.IntImm, expr.UIntImm)):
if not isinstance(elem, (expr.IntImm)):
ret.append(elem)
else:
ret.append(get_const_int(elem))
Expand Down
Loading

0 comments on commit 8b6aeaa

Please sign in to comment.