Skip to content

Commit

Permalink
[REFACTOR][IR] Polish ir/type (apache#4705)
Browse files Browse the repository at this point in the history
- Use consistent constructor style to construct objects.
- Move env_func to ir as it is mainly used to construct IRs.
- Make docs consistent.
  • Loading branch information
tqchen authored and zhiics committed Mar 2, 2020
1 parent 1818c03 commit 96dc161
Show file tree
Hide file tree
Showing 31 changed files with 415 additions and 273 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* \file tvm/relay/adt.h
* \file tvm/ir/adt.h
* \brief Algebraic data type definitions.
*
* We adopt relay's ADT definition as a unified class
Expand Down
24 changes: 12 additions & 12 deletions include/tvm/node/env_func.h → include/tvm/ir/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,24 @@
*/

/*!
* \file tvm/node/env_func.h
* \brief Serializable global function.
* \file tvm/ir/env_func.h
* \brief Serializable global function used in IR.
*/
#ifndef TVM_NODE_ENV_FUNC_H_
#define TVM_NODE_ENV_FUNC_H_
#ifndef TVM_IR_ENV_FUNC_H_
#define TVM_IR_ENV_FUNC_H_

#include <tvm/node/reflection.h>

#include <string>
#include <utility>


namespace tvm {
/*!
* \brief Node container of EnvFunc
* \brief A serializable function backed by TVM's global environment.
*
* This is a wrapper to enable serializable global PackedFunc.
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
* \sa EnvFunc
*/
class EnvFuncNode : public Object {
Expand All @@ -53,11 +56,8 @@ class EnvFuncNode : public Object {
};

/*!
* \brief A serializable function backed by TVM's global environment.
*
* This is a wrapper to enable serializable global PackedFunc.
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
* \brief Managed reference to EnvFuncNode.
* \sa EnvFuncNode
*/
class EnvFunc : public ObjectRef {
public:
Expand Down Expand Up @@ -140,4 +140,4 @@ class TypedEnvFunc<R(Args...)> : public ObjectRef {
};

} // namespace tvm
#endif // TVM_NODE_ENV_FUNC_H_
#endif // TVM_IR_ENV_FUNC_H_
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class GlobalVarNode : public RelayExprNode {
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;

void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
Expand Down
22 changes: 11 additions & 11 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ class IRModule;
class IRModuleNode : public Object {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, BaseFunc> functions;
Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;
Map<GlobalTypeVar, TypeData> type_definitions;

IRModuleNode() {}

void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("functions", &functions);
v->Visit("type_definitions", &type_definitions);
v->Visit("global_var_map_", &global_var_map_);
Expand Down Expand Up @@ -146,7 +146,7 @@ class IRModuleNode : public Object {
* \brief Collect all global vars defined in this module.
* \returns An array of global vars
*/
TVM_DLL tvm::Array<GlobalVar> GetGlobalVars() const;
TVM_DLL Array<GlobalVar> GetGlobalVars() const;

/*!
* \brief Look up a global function by its name.
Expand All @@ -159,7 +159,7 @@ class IRModuleNode : public Object {
* \brief Collect all global type vars defined in this module.
* \returns An array of global type vars
*/
TVM_DLL tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const;
TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const;

/*!
* \brief Look up a global function by its variable.
Expand Down Expand Up @@ -235,12 +235,12 @@ class IRModuleNode : public Object {
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
tvm::Map<std::string, GlobalVar> global_var_map_;
Map<std::string, GlobalVar> global_var_map_;

/*! \brief A map from string names to global type variables (ADT names)
* that ensures global uniqueness.
*/
tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;
Map<std::string, GlobalTypeVar> global_type_var_map_;

/*! \brief A map from constructor tags to constructor objects
* for convenient access
Expand All @@ -266,8 +266,8 @@ class IRModule : public ObjectRef {
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module
*/
TVM_DLL explicit IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions = {},
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<std::string> import_set = {});
/*! \brief default constructor */
IRModule() {}
Expand Down Expand Up @@ -296,8 +296,8 @@ class IRModule : public ObjectRef {
*/
TVM_DLL static IRModule FromExpr(
const RelayExpr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
const Map<GlobalVar, BaseFunc>& global_funcs = {},
const Map<GlobalTypeVar, TypeData>& type_definitions = {});

/*!
* \brief Parse text format source file into an IRModule.
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class OpNode : public RelayExprNode {
*/
int32_t support_level = 10;

void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("op_type", &op_type);
v->Visit("description", &description);
Expand Down Expand Up @@ -476,15 +476,15 @@ inline OpRegistry& OpRegistry::add_type_rel(
std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i);
auto param = TypeVarNode::make(name, TypeKind::kType);
auto param = TypeVar(name, TypeKind::kType);
type_params.push_back(param);
arg_types.push_back(param);
}

Array<Type> ty_call_args = arg_types;

// Add output type.
auto out_param = TypeVarNode::make("out", TypeKind::kType);
auto out_param = TypeVar("out", TypeKind::kType);
type_params.push_back(out_param);
// this will trigger copy on write.
ty_call_args.push_back(out_param);
Expand All @@ -498,13 +498,13 @@ inline OpRegistry& OpRegistry::add_type_rel(
// A common example is sum(x, axis), where the choice of axis
// can affect the type of the function.
TypeConstraint type_rel =
TypeRelationNode::make(env_type_rel_func,
TypeRelation(env_type_rel_func,
ty_call_args,
arg_types.size(),
Attrs());

auto func_type =
FuncTypeNode::make(arg_types, out_param, type_params, {type_rel});
FuncType(arg_types, out_param, type_params, {type_rel});

get()->op_type = func_type;

Expand Down
26 changes: 13 additions & 13 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ class PassContextNode : public Object {
int fallback_device{static_cast<int>(kDLCPU)};

/*! \brief The list of required passes. */
tvm::Array<tvm::PrimExpr> required_pass;
Array<PrimExpr> required_pass;
/*! \brief The list of disabled passes. */
tvm::Array<tvm::PrimExpr> disabled_pass;
Array<PrimExpr> disabled_pass;

PassContextNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
Expand Down Expand Up @@ -118,7 +118,7 @@ class PassContextNode : public Object {
class PassContext : public ObjectRef {
public:
PassContext() {}
explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {}
explicit PassContext(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief const accessor.
* \return const access pointer.
Expand Down Expand Up @@ -158,7 +158,7 @@ class PassContext : public ObjectRef {

// Classes to get the Python `with` like syntax.
friend class Internal;
friend class tvm::With<PassContext>;
friend class With<PassContext>;
};

/*!
Expand All @@ -174,11 +174,11 @@ class PassInfoNode : public Object {
std::string name;

/*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::PrimExpr> required;
Array<PrimExpr> required;

PassInfoNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
Expand All @@ -202,7 +202,7 @@ class PassInfo : public ObjectRef {
*/
TVM_DLL PassInfo(int opt_level,
std::string name,
tvm::Array<tvm::PrimExpr> required);
Array<PrimExpr> required);

TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
Expand Down Expand Up @@ -241,7 +241,7 @@ class PassNode : public Object {
virtual IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const = 0;

void VisitAttrs(tvm::AttrVisitor* v) {}
void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
Expand Down Expand Up @@ -289,7 +289,7 @@ class Sequential : public Pass {
* \param passes The passes to apply.
* \param pass_info The pass metadata.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info);

/*!
* \brief The constructor of `Sequential`.
Expand All @@ -299,10 +299,10 @@ class Sequential : public Pass {
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");
TVM_DLL Sequential(Array<Pass> passes, std::string name = "sequential");

Sequential() = default;
explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {}
explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}

const SequentialNode* operator->() const;
using ContainerType = Sequential;
Expand All @@ -322,7 +322,7 @@ Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
const Array<PrimExpr>& required);

} // namespace transform
} // namespace tvm
Expand Down
Loading

0 comments on commit 96dc161

Please sign in to comment.