Skip to content

Commit

Permalink
[REFACTOR][IR] attrs.h -> ir (apache#4709)
Browse files Browse the repository at this point in the history
This PR moves attrs.h into the ir folder as it
can serve as a common infra for building ir dats structures.

We also moved common container(FloatImm) into ir/expr.h
  • Loading branch information
tqchen authored and alexwong committed Feb 26, 2020
1 parent 0e2f40a commit ee345c1
Show file tree
Hide file tree
Showing 41 changed files with 174 additions and 123 deletions.
4 changes: 2 additions & 2 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -677,13 +677,13 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
}
}
if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
if (t.is_float()) return FloatImm(t, static_cast<double>(value));
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) {
return ir::FloatImmNode::make(t, static_cast<double>(value));
return FloatImm(t, static_cast<double>(value));
}
LOG(FATAL) << "cannot make const for type " << t;
return PrimExpr();
Expand Down
18 changes: 1 addition & 17 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,9 @@ namespace tvm {
namespace ir {

using IntImmNode = tvm::IntImmNode;
using FloatImmNode = tvm::FloatImmNode;
using VarNode = tvm::VarNode;

/*! \brief Floating point constants. */
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}

TVM_DLL static PrimExpr make(DataType t, double value);

static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};

/*! \brief String constants, only used in asserts. */
class StringImmNode : public PrimExprNode {
public:
Expand Down
53 changes: 21 additions & 32 deletions include/tvm/attrs.h → include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/attrs.h
* \brief TVM attribute module
* \file tvm/ir/attrs.h
* \brief Helpers for attribute objects.
*
* This module enables declaration of named attributes
* which support default value setup and bound checking.
Expand All @@ -42,31 +41,30 @@
*
* \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD
*/
#ifndef TVM_ATTRS_H_
#define TVM_ATTRS_H_
#ifndef TVM_IR_ATTRS_H_
#define TVM_IR_ATTRS_H_

#include <dmlc/common.h>
#include <tvm/ir/expr.h>
#include <tvm/runtime/packed_func.h>

#include <unordered_map>
#include <vector>
#include <functional>
#include <type_traits>
#include <string>
#include <utility>
#include "ir.h"
#include "base.h"
#include "expr.h"
#include "packed_func_ext.h"

namespace tvm {
/*!
* \brief Declare an attribute function.
* \param ClassName The name of the class.
* \param TypeKey The type key to be used by the TVM node system.
*/
#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
static constexpr const char* _type_key = TypeKey; \
#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
static constexpr const char* _type_key = TypeKey; \
TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
template<typename FVisit> \
template<typename FVisit> \
void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*)


Expand Down Expand Up @@ -481,45 +479,36 @@ template<typename T>
inline void SetValue(T* ptr, const TVMArgValue& val) {
*ptr = val.operator T();
}

template<typename T>
inline void SetIntValue(T* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLInt) {
*ptr = static_cast<T>(val.value().v_int64);
} else {
PrimExpr expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
}
IntImm expr = val;
*ptr = static_cast<T>(expr->value);
}
}

template<>
inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
if (val.type_code() == kStr) {
*ptr = val.operator std::string();
} else {
PrimExpr expr = val;
const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
CHECK(op != nullptr);
*ptr = op->value;
LOG(FATAL) << "Expect str";
}
}
template<>
inline void SetValue(DataType* ptr, const TVMArgValue& val) {
*ptr = val.operator DataType();
}

template<>
inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
*ptr = val.operator double();
} else {
PrimExpr expr = val;
ObjectRef expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
if (const IntImmNode* op = expr.as<IntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
} else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
*ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
Expand Down Expand Up @@ -611,7 +600,7 @@ struct TypeName<uint64_t> {

template<>
struct TypeName<DataType> {
static constexpr const char* value = "Type";
static constexpr const char* value = "DataType";
};

template<>
Expand Down Expand Up @@ -872,4 +861,4 @@ inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*)
}

} // namespace tvm
#endif // TVM_ATTRS_H_
#endif // TVM_IR_ATTRS_H_
50 changes: 50 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,56 @@ class IntImm : public PrimExpr {
using ContainerType = IntImmNode;
};

/*!
* \brief Constant floating point literals in the program.
* \sa FloatImm
*/
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}

static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};

/*!
* \brief Managed reference class to FloatImmNode.
*
* \sa FloatImmNode
*/
class FloatImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
FloatImm() {}
/*!
* \brief constructor from node.
*/
explicit FloatImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL FloatImm(DataType dtype, double value);
/*!
* \brief Get pointer to the container.
* \return The pointer.
*/
const FloatImmNode* operator->() const {
return static_cast<const FloatImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = FloatImmNode;
};

/*!
* \brief Base node of all non-primitive expressions.
*
Expand Down
14 changes: 8 additions & 6 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#define TVM_IR_OP_H_

#include <dmlc/registry.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type.h>
Expand Down Expand Up @@ -296,7 +296,8 @@ class OpRegistry {
// return internal pointer to op.
inline OpNode* get();
// update the attribute OpMap
TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value,
TVM_DLL void UpdateAttr(const std::string& key,
runtime::TVMRetValue value,
int plevel);
};

Expand All @@ -316,7 +317,7 @@ class GenericOpMap {
* \param op The key to the map
* \return the const reference to the content value.
*/
inline const TVMRetValue& operator[](const Op& op) const;
inline const runtime::TVMRetValue& operator[](const Op& op) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param op The key to the map
Expand All @@ -342,7 +343,7 @@ class GenericOpMap {
// the attribute field.
std::string attr_name_;
// internal data
std::vector<std::pair<TVMRetValue, int> > data_;
std::vector<std::pair<runtime::TVMRetValue, int> > data_;
// The value
GenericOpMap() = default;
};
Expand Down Expand Up @@ -532,7 +533,7 @@ template <typename ValueType>
inline OpRegistry& OpRegistry::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value, int plevel) {
CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
TVMRetValue rv;
runtime::TVMRetValue rv;
rv = value;
UpdateAttr(attr_name, rv, plevel);
return *this;
Expand All @@ -548,7 +549,8 @@ inline int GenericOpMap::count(const Op& op) const {
}
}

inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const {
inline const runtime::TVMRetValue&
GenericOpMap::operator[](const Op& op) const {
CHECK(op.defined());
const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second != 0)
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/type_relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/ir/type.h>
#include <tvm/ir/module.h>
#include <tvm/ir/env_func.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>

namespace tvm {

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/adt.h>
#include <string>
#include <functional>
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_ALGORITHM_H_
#define TVM_RELAY_ATTRS_ALGORITHM_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_ANNOTATION_H_
#define TVM_RELAY_ATTRS_ANNOTATION_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/bitserial.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#ifndef TVM_RELAY_ATTRS_BITSERIAL_H_
#define TVM_RELAY_ATTRS_BITSERIAL_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_DEBUG_H_
#define TVM_RELAY_ATTRS_DEBUG_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/device_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_
#define TVM_RELAY_ATTRS_DEVICE_COPY_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_IMAGE_H_
#define TVM_RELAY_ATTRS_IMAGE_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_MEMORY_H_
#define TVM_RELAY_ATTRS_MEMORY_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/expr.h>
#include <string>

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_NN_H_
#define TVM_RELAY_ATTRS_NN_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_REDUCE_H_
#define TVM_RELAY_ATTRS_REDUCE_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_TRANSFORM_H_
#define TVM_RELAY_ATTRS_TRANSFORM_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <tvm/relay/expr.h>
#include <string>
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_VISION_H_
#define TVM_RELAY_ATTRS_VISION_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>

Expand Down
Loading

0 comments on commit ee345c1

Please sign in to comment.