Skip to content

Commit

Permalink
[NODE][REFACTOR] Refactor reflection system in node. (apache#4189)
Browse files Browse the repository at this point in the history
* [NODE][REFACTOR] Refactor reflection system in node.

- Removed the old Node, Node is now just an alias of runtime::Object
- Introduce ReflectionVTable, a new columnar dispatcher to support reflection
  - This allows us to remove vtable from most node objects
  - The VisitAttrs are registered via TVM_RESGITER_NODE_TYPE,
    they are no longer virtual.
- Consolidated serialization and reflection features into node.

* Explicit type qualification when calling destructor.

* Fix SPIRV, more comments
  • Loading branch information
tqchen authored and kevinthesun committed Oct 30, 2019
1 parent 5884ea9 commit 1a11efb
Show file tree
Hide file tree
Showing 76 changed files with 1,105 additions and 941 deletions.
2 changes: 1 addition & 1 deletion include/tvm/api_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class EnvFuncNode : public Node {
/*! \brief constructor */
EnvFuncNode() {}

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
}

Expand Down
6 changes: 3 additions & 3 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ConstIntBoundNode : public Node {
int64_t min_value;
int64_t max_value;

void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("min_value", &min_value);
v->Visit("max_value", &max_value);
}
Expand Down Expand Up @@ -162,7 +162,7 @@ class ModularSetNode : public Node {
/*! \brief The base */
int64_t base;

void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("coeff", &coeff);
v->Visit("base", &base);
}
Expand Down Expand Up @@ -351,7 +351,7 @@ enum SignType {
*/
struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object);
};

/*!
Expand Down
10 changes: 6 additions & 4 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class AttrFieldInfoNode : public Node {
/*! \brief detailed description of the type */
std::string description;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("type_info", &type_info);
v->Visit("description", &description);
Expand Down Expand Up @@ -197,7 +197,7 @@ class AttrsHash {
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const Type& value) const {
size_t operator()(const DataType& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
Expand All @@ -221,6 +221,8 @@ class BaseAttrsNode : public Node {
public:
using TVMArgs = runtime::TVMArgs;
using TVMRetValue = runtime::TVMRetValue;
// visit function
virtual void VisitAttrs(AttrVisitor* v) {}
/*!
* \brief Initialize the attributes by sequence of arguments
* \param args The postional arguments in the form
Expand Down Expand Up @@ -753,12 +755,12 @@ class AttrNonDefaultVisitor {
template<typename DerivedType>
class AttrsNode : public BaseAttrsNode {
public:
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
::tvm::detail::AttrNormalVisitor vis(v);
self()->__VisitAttrs__(vis);
}

void VisitNonDefaultAttrs(AttrVisitor* v) final {
void VisitNonDefaultAttrs(AttrVisitor* v) {
::tvm::detail::AttrNonDefaultVisitor vis(v);
self()->__VisitAttrs__(vis);
}
Expand Down
169 changes: 1 addition & 168 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,89 +19,16 @@

/*!
* \file tvm/base.h
* \brief Defines the base data structure
* \brief Base utilities
*/
#ifndef TVM_BASE_H_
#define TVM_BASE_H_

#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <tvm/node/node.h>
#include <string>
#include <memory>
#include <functional>
#include <utility>
#include "runtime/registry.h"

namespace tvm {

using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;

/*!
* \brief Macro to define common node ref methods.
* \param TypeName The name of the NodeRef.
* \param BaseTypeName The Base type.
* \param NodeName The node container type.
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(data_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;

/*!
* \brief Macro to define CopyOnWrite function in a NodeRef.
* \param NodeName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<NodeName*>(data_.get()); \
}

/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \

/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};

/*!
* \brief RAII wrapper function to enter and exit a context object
* similar to python's with syntax.
Expand Down Expand Up @@ -146,100 +73,6 @@ class With {
ContextType ctx_;
};

/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
* \return the string representation of the node.
*/
std::string SaveJSON(const NodeRef& node);

/*!
* \brief Internal implementation of LoadJSON
* Load tvm Node object from json and return a shared_ptr of Node.
* \param json_str The json string to load from.
*
* \return The shared_ptr of the Node.
*/
ObjectPtr<Object> LoadJSON_(std::string json_str);

/*!
* \brief Load the node from json string.
* This can be used to deserialize any TVM object.
*
* \param json_str The json string to load from.
*
* \tparam NodeType the nodetype
*
* \code
* Expr e = LoadJSON<Expr>(json_str);
* \endcode
*/
template<typename NodeType,
typename = typename std::enable_if<std::is_base_of<NodeRef, NodeType>::value>::type >
inline NodeType LoadJSON(const std::string& json_str) {
return NodeType(LoadJSON_(json_str));
}

/*!
* \brief Registry entry for NodeFactory.
*
* There are two types of Nodes that can be serialized.
* The normal node requires a registration a creator function that
* constructs an empty Node of the corresponding type.
*
* The global singleton(e.g. global operator) where only global_key need to be serialized,
* in this case, FGlobalKey need to be defined.
*/
struct NodeFactoryReg {
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey
* \return The created function.
*/
using FCreate = std::function<NodePtr<Node>(const std::string& global_key)>;
/*!
* \brief Global key function, only needed by global objects.
* \param node The node pointer.
* \return node The global key to the node.
*/
using FGlobalKey = std::function<std::string(const Node* node)>;
/*! \brief registered name */
std::string name;
/*!
* \brief The creator function
*/
FCreate fcreator = nullptr;
/*!
* \brief The global key function.
*/
FGlobalKey fglobal_key = nullptr;
// setter of creator
NodeFactoryReg& set_creator(FCreate f) { // NOLINT(*)
this->fcreator = f;
return *this;
}
// setter of creator
NodeFactoryReg& set_global_key(FGlobalKey f) { // NOLINT(*)
this->fglobal_key = f;
return *this;
}
// global registry singleton
TVM_DLL static ::dmlc::Registry<::tvm::NodeFactoryReg> *Registry();
};

/*!
* \brief Register a Node type
* \note This is necessary to enable serialization of the Node.
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
.set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); })


#define TVM_STRINGIZE_DETAIL(x) #x
#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x)
#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__))
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class BufferNode : public Node {
/*! \brief constructor */
BufferNode() {}

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("dtype", &dtype);
v->Visit("shape", &shape);
Expand Down
6 changes: 4 additions & 2 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class TargetNode : public Node {
/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("target_name", &target_name);
v->Visit("device_name", &device_name);
v->Visit("device_type", &device_type);
Expand Down Expand Up @@ -229,7 +229,7 @@ class BuildConfigNode : public Node {
/*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
Expand Down Expand Up @@ -473,6 +473,8 @@ class GenericFuncNode : public Node {
/* \brief map from keys to registered functions */
std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;

void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "GenericFunc";
TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
};
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct ChannelNode : public Node {
/*! \brief default data type in read/write */
Type dtype;
// visit all attributes
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("handle_var", &handle_var);
v->Visit("dtype", &dtype);
}
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class LayoutNode : public Node {
*/
Array<IterVar> axes;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("axes", &axes);
}
Expand Down Expand Up @@ -325,7 +325,7 @@ class BijectiveLayoutNode : public Node {
/*! \brief The destination layout */
Layout dst_layout;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("src_layout", &src_layout);
v->Visit("dst_layout", &dst_layout);
v->Visit("forward_rule", &forward_rule);
Expand Down
12 changes: 7 additions & 5 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
#include <string>
#include <algorithm>
#include <unordered_map>
#include <iostream>
#include "base.h"
#include "dtype.h"
#include "node/node.h"
#include "node/container.h"
#include "node/ir_functor.h"
#include "runtime/c_runtime_api.h"
Expand Down Expand Up @@ -110,7 +112,7 @@ class Variable : public ExprNode {

static Var make(DataType dtype, std::string name_hint);

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("name", &name_hint);
}
Expand Down Expand Up @@ -164,7 +166,7 @@ class IntImm : public ExprNode {
/*! \brief the Internal value. */
int64_t value;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
Expand Down Expand Up @@ -230,7 +232,7 @@ class RangeNode : public Node {
RangeNode() {}
RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("min", &min);
v->Visit("extent", &extent);
}
Expand Down Expand Up @@ -406,7 +408,7 @@ class IterVarNode : public Node {
*/
std::string thread_tag;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dom", &dom);
v->Visit("var", &var);
v->Visit("iter_type", &iter_type);
Expand Down Expand Up @@ -490,7 +492,7 @@ class IRPrinter {
};

// default print function for all nodes
inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
IRPrinter(os).Print(n);
return os;
}
Expand Down
Loading

0 comments on commit 1a11efb

Please sign in to comment.