Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][IR] Polish ir/type #4705

Merged
merged 1 commit into from
Jan 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/ir/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* \file tvm/relay/adt.h
* \file tvm/ir/adt.h
* \brief Algebraic data type definitions.
*
* We adopt relay's ADT definition as a unified class
Expand Down
24 changes: 12 additions & 12 deletions include/tvm/node/env_func.h → include/tvm/ir/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,24 @@
*/

/*!
* \file tvm/node/env_func.h
* \brief Serializable global function.
* \file tvm/ir/env_func.h
* \brief Serializable global function used in IR.
*/
#ifndef TVM_NODE_ENV_FUNC_H_
#define TVM_NODE_ENV_FUNC_H_
#ifndef TVM_IR_ENV_FUNC_H_
#define TVM_IR_ENV_FUNC_H_

#include <tvm/node/reflection.h>

#include <string>
#include <utility>


namespace tvm {
/*!
* \brief Node container of EnvFunc
* \brief A serializable function backed by TVM's global environment.
*
* This is a wrapper to enable serializable global PackedFunc.
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
* \sa EnvFunc
*/
class EnvFuncNode : public Object {
Expand All @@ -53,11 +56,8 @@ class EnvFuncNode : public Object {
};

/*!
* \brief A serializable function backed by TVM's global environment.
*
* This is a wrapper to enable serializable global PackedFunc.
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
* \brief Managed reference to EnvFuncNode.
* \sa EnvFuncNode
*/
class EnvFunc : public ObjectRef {
public:
Expand Down Expand Up @@ -140,4 +140,4 @@ class TypedEnvFunc<R(Args...)> : public ObjectRef {
};

} // namespace tvm
#endif // TVM_NODE_ENV_FUNC_H_
#endif // TVM_IR_ENV_FUNC_H_
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class GlobalVarNode : public RelayExprNode {
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;

void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
Expand Down
22 changes: 11 additions & 11 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ class IRModule;
class IRModuleNode : public Object {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, BaseFunc> functions;
Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;
Map<GlobalTypeVar, TypeData> type_definitions;

IRModuleNode() {}

void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("functions", &functions);
v->Visit("type_definitions", &type_definitions);
v->Visit("global_var_map_", &global_var_map_);
Expand Down Expand Up @@ -146,7 +146,7 @@ class IRModuleNode : public Object {
* \brief Collect all global vars defined in this module.
* \returns An array of global vars
*/
TVM_DLL tvm::Array<GlobalVar> GetGlobalVars() const;
TVM_DLL Array<GlobalVar> GetGlobalVars() const;

/*!
* \brief Look up a global function by its name.
Expand All @@ -159,7 +159,7 @@ class IRModuleNode : public Object {
* \brief Collect all global type vars defined in this module.
* \returns An array of global type vars
*/
TVM_DLL tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const;
TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const;

/*!
* \brief Look up a global function by its variable.
Expand Down Expand Up @@ -235,12 +235,12 @@ class IRModuleNode : public Object {
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
tvm::Map<std::string, GlobalVar> global_var_map_;
Map<std::string, GlobalVar> global_var_map_;

/*! \brief A map from string names to global type variables (ADT names)
* that ensures global uniqueness.
*/
tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;
Map<std::string, GlobalTypeVar> global_type_var_map_;

/*! \brief A map from constructor tags to constructor objects
* for convenient access
Expand All @@ -266,8 +266,8 @@ class IRModule : public ObjectRef {
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module
*/
TVM_DLL explicit IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions = {},
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<std::string> import_set = {});
/*! \brief default constructor */
IRModule() {}
Expand Down Expand Up @@ -296,8 +296,8 @@ class IRModule : public ObjectRef {
*/
TVM_DLL static IRModule FromExpr(
const RelayExpr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
const Map<GlobalVar, BaseFunc>& global_funcs = {},
const Map<GlobalTypeVar, TypeData>& type_definitions = {});

/*!
* \brief Parse text format source file into an IRModule.
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class OpNode : public RelayExprNode {
*/
int32_t support_level = 10;

void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("op_type", &op_type);
v->Visit("description", &description);
Expand Down Expand Up @@ -476,15 +476,15 @@ inline OpRegistry& OpRegistry::add_type_rel(
std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i);
auto param = TypeVarNode::make(name, TypeKind::kType);
auto param = TypeVar(name, TypeKind::kType);
type_params.push_back(param);
arg_types.push_back(param);
}

Array<Type> ty_call_args = arg_types;

// Add output type.
auto out_param = TypeVarNode::make("out", TypeKind::kType);
auto out_param = TypeVar("out", TypeKind::kType);
type_params.push_back(out_param);
// this will trigger copy on write.
ty_call_args.push_back(out_param);
Expand All @@ -498,13 +498,13 @@ inline OpRegistry& OpRegistry::add_type_rel(
// A common example is sum(x, axis), where the choice of axis
// can affect the type of the function.
TypeConstraint type_rel =
TypeRelationNode::make(env_type_rel_func,
TypeRelation(env_type_rel_func,
ty_call_args,
arg_types.size(),
Attrs());

auto func_type =
FuncTypeNode::make(arg_types, out_param, type_params, {type_rel});
FuncType(arg_types, out_param, type_params, {type_rel});

get()->op_type = func_type;

Expand Down
26 changes: 13 additions & 13 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ class PassContextNode : public Object {
int fallback_device{static_cast<int>(kDLCPU)};

/*! \brief The list of required passes. */
tvm::Array<tvm::PrimExpr> required_pass;
Array<PrimExpr> required_pass;
/*! \brief The list of disabled passes. */
tvm::Array<tvm::PrimExpr> disabled_pass;
Array<PrimExpr> disabled_pass;

PassContextNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
Expand Down Expand Up @@ -118,7 +118,7 @@ class PassContextNode : public Object {
class PassContext : public ObjectRef {
public:
PassContext() {}
explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {}
explicit PassContext(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief const accessor.
* \return const access pointer.
Expand Down Expand Up @@ -158,7 +158,7 @@ class PassContext : public ObjectRef {

// Classes to get the Python `with` like syntax.
friend class Internal;
friend class tvm::With<PassContext>;
friend class With<PassContext>;
};

/*!
Expand All @@ -174,11 +174,11 @@ class PassInfoNode : public Object {
std::string name;

/*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::PrimExpr> required;
Array<PrimExpr> required;

PassInfoNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
Expand All @@ -202,7 +202,7 @@ class PassInfo : public ObjectRef {
*/
TVM_DLL PassInfo(int opt_level,
std::string name,
tvm::Array<tvm::PrimExpr> required);
Array<PrimExpr> required);

TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
Expand Down Expand Up @@ -241,7 +241,7 @@ class PassNode : public Object {
virtual IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const = 0;

void VisitAttrs(tvm::AttrVisitor* v) {}
void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
Expand Down Expand Up @@ -289,7 +289,7 @@ class Sequential : public Pass {
* \param passes The passes to apply.
* \param pass_info The pass metadata.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info);

/*!
* \brief The constructor of `Sequential`.
Expand All @@ -299,10 +299,10 @@ class Sequential : public Pass {
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");
TVM_DLL Sequential(Array<Pass> passes, std::string name = "sequential");

Sequential() = default;
explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {}
explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}

const SequentialNode* operator->() const;
using ContainerType = Sequential;
Expand All @@ -322,7 +322,7 @@ Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
const Array<PrimExpr>& required);

} // namespace transform
} // namespace tvm
Expand Down
Loading