Skip to content

Commit

Permalink
[REFACTOR][IR] Unified IR IRModule structure. (#4699)
Browse files Browse the repository at this point in the history
This PR brings relay::Module as the unified IRModule structure.
IRModule will be used as the basic unit for transformations
through out the stack.

- Rename relay::Module -> IRModule
- Move relay/module.h -> ir/module.h
- ModuleNode::FromExpr -> IRModule::FromExpr
- FromText -> IRModule::FromText
  • Loading branch information
tqchen authored Jan 14, 2020
1 parent bd17baa commit c69092a
Show file tree
Hide file tree
Showing 60 changed files with 384 additions and 374 deletions.
2 changes: 1 addition & 1 deletion docs/dev/relay_pass_infra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ registration.
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});

// Create a module for optimization.
auto mod = relay::ModuleNode::FromExpr(fx);
auto mod = IRModule::FromExpr(fx);

// Create a sequential pass.
tvm::Array<relay::transform::Pass> pass_seqs{
Expand Down
145 changes: 73 additions & 72 deletions include/tvm/relay/module.h → include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,55 +18,41 @@
*/

/*!
* \file tvm/relay/module.h
* \brief The global environment: contains information needed to
* compile & optimize Relay programs.
* \file tvm/ir/module.h
* \brief IRModule that holds the functions and type definitions.
*/
#ifndef TVM_RELAY_MODULE_H_
#define TVM_RELAY_MODULE_H_

#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#ifndef TVM_IR_MODULE_H_
#define TVM_IR_MODULE_H_

#include <tvm/ir/type.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/adt.h>

#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>

namespace tvm {
namespace relay {

struct Module;

/*! \brief The global environment of Relay programs.
*
* The global environment contains the global
* information needed to compile a Relay program.
*
* It contains all global functions, and configuration
* options.
class IRModule;
/*!
* \brief IRModule that holds functions and type definitions.
*
* Many operations require access to the global
* Module. We pass the Module by value
* in a functional style as an explicit argument,
* but we mutate the Module while optimizing
* Relay programs.
* IRModule is the basic unit for all IR transformations across the stack.
*
* The functional style allows users to construct custom
* environments easily, for example each thread can store
* a Module while auto-tuning.
* Many operations require access to the global IRModule.
* We pass the IRModule by value in a functional style as an explicit argument,
* but we mutate the Module while optimizing programs.
* \sa IRModule
*/

class ModuleNode : public RelayNode {
class IRModuleNode : public Object {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;

ModuleNode() {}
IRModuleNode() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("functions", &functions);
Expand All @@ -75,10 +61,6 @@ class ModuleNode : public RelayNode {
v->Visit("global_type_var_map_", &global_type_var_map_);
}

TVM_DLL static Module make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports = {});

/*!
* \brief Add a function to the global environment.
* \param var The var of the global function.
Expand Down Expand Up @@ -219,7 +201,7 @@ class ModuleNode : public RelayNode {
* functions in another environment.
* \param other The other environment.
*/
TVM_DLL void Update(const Module& other);
TVM_DLL void Update(const IRModule& other);

/*!
* \brief Import Relay code from the file at path.
Expand All @@ -243,24 +225,8 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL std::unordered_set<std::string> Imports() const;

/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map and
* map of type definitions as well.
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map.
* \param type_definitions Map of global type definitions
*
* \returns A module with expr set as the main function.
*/
TVM_DLL static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});

static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);

private:
/*! \brief Helper function for registering a typedef's constructors */
Expand All @@ -285,27 +251,62 @@ class ModuleNode : public RelayNode {
importing is idempotent for each module.
*/
std::unordered_set<std::string> import_set_;
friend class IRModule;
};

struct Module : public ObjectRef {
Module() {}
explicit Module(ObjectPtr<::tvm::Object> p) : ObjectRef(p) {}

ModuleNode* operator->() const {
return static_cast<ModuleNode*>(get_mutable());
/*!
* \brief Managed reference class to IRModuleNode.
* \sa IRModuleNode
*/
class IRModule : public ObjectRef {
public:
/*!
* \brief constructor
* \param functions Functions in the module.
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module
*/
TVM_DLL explicit IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<std::string> import_set = {});
/*! \brief default constructor */
IRModule() {}
/*!
* \brief constructor
* \param n The object pointer.
*/
explicit IRModule(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return mutable pointers to the node. */
IRModuleNode* operator->() const {
auto* ptr = get_mutable();
CHECK(ptr != nullptr);
return static_cast<IRModuleNode*>(ptr);
}
/*!
* \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map and
* map of type definitions as well.
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map.
* \param type_definitions Map of global type definitions
*
* \returns A module with expr set as the main function.
*/
TVM_DLL static IRModule FromExpr(
const RelayExpr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});

using ContainerType = ModuleNode;
/*!
* \brief Parse text format source file into an IRModule.
* \param text A string of Relay source code.
* \param source_path The path to the source file.
* \return A Relay module.
*/
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
};

/*! \brief Parse Relay source into a module.
* \param source A string of Relay source code.
* \param source_name The name of the source file.
* \return A Relay module.
*/
Module FromText(const std::string& source, const std::string& source_name);

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_MODULE_H_
#endif // TVM_IR_MODULE_H_
7 changes: 3 additions & 4 deletions include/tvm/ir/type_relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@
namespace tvm {

// TODO(tqchen): remove after migrate Module to ir.
namespace relay {
struct Module;
}
class IRModule;


/*!
* \brief reporter that reports back to the
Expand Down Expand Up @@ -76,7 +75,7 @@ class TypeReporterNode : public Object {
* \brief Retrieve the current global module.
* \return The global module.
*/
TVM_DLL virtual relay::Module GetModule() = 0;
TVM_DLL virtual IRModule GetModule() = 0;

// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) {}
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Stmt CanonicalSimplify(Stmt stmt,
* \return Canonicalized expression.
*/
TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
Map<Var, Range> vrange = Map<Var, Range>());
Map<Var, Range> vrange = Map<Var, Range>());

/*!
* \brief Deep compare lhs and rhs
Expand Down
18 changes: 9 additions & 9 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/ir/module.h>
#include <tvm/relay/type.h>
#include <string>

Expand All @@ -49,7 +49,7 @@ namespace relay {
*
* \return The kind of the passed type.
*/
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);

/*!
* \brief Check whether an expression is constant.
Expand Down Expand Up @@ -188,7 +188,7 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
*
* \return List of free vars, in the PostDFS order visited by expr.
*/
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const IRModule& mod);

/*!
* \brief Get free TypeVars from type t.
Expand All @@ -201,7 +201,7 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
*
* \return List of free type vars, in the PostDFS order visited by type.
*/
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const IRModule& mod);

/*!
* \brief Get all bound type variables from expression expr.
Expand All @@ -214,7 +214,7 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
*
* \return List of bound type vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const IRModule& mod);

/*!
* \brief Get all bound type variables from type t.
Expand All @@ -227,7 +227,7 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
*
* \return List of bound type vars, in the PostDFS order visited by type.
*/
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const IRModule& mod);

/*!
* \brief Get all type variables in expression expr.
Expand All @@ -237,7 +237,7 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
*
* \return List of type vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod);

/*!
* \brief Get all type variables in type t.
Expand All @@ -247,7 +247,7 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
*
* \return List of type vars, in the PostDFS order visited by type.
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const IRModule& mod);

/*!
* \brief Collect the device mapping information of each expression.
Expand Down Expand Up @@ -277,7 +277,7 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
* \return Returns a list of cases (as patterns) that are not handled by the match
* expression.
*/
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
Expand Down
3 changes: 0 additions & 3 deletions include/tvm/relay/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,6 @@ class Id : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
};


struct Module;

} // namespace relay
} // namespace tvm

Expand Down
7 changes: 5 additions & 2 deletions include/tvm/relay/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@
#ifndef TVM_RELAY_ERROR_H_
#define TVM_RELAY_ERROR_H_

#include <tvm/ir/module.h>

#include <string>
#include <vector>
#include <sstream>
#include <unordered_map>

#include "./base.h"
#include "./expr.h"
#include "./module.h"


namespace tvm {
namespace relay {
Expand Down Expand Up @@ -146,7 +149,7 @@ class ErrorReporter {
* \param module The module to report errors on.
* \param use_color Controls whether to colorize the output.
*/
void RenderErrors(const Module& module, bool use_color = true);
void RenderErrors(const IRModule& module, bool use_color = true);

inline bool AnyErrors() {
return errors_.size() != 0;
Expand Down
7 changes: 4 additions & 3 deletions include/tvm/relay/feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

#include <tvm/node/container.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/module.h>

#include <bitset>

namespace tvm {
Expand Down Expand Up @@ -141,15 +143,14 @@ class FeatureSet {
*/
FeatureSet DetectFeature(const RelayExpr& expr);

struct Module;
/*!
* \brief Calculate the feature of the program.
*
* \param mod The module.
*
* \return The FeatureSet.
*/
FeatureSet DetectFeature(const Module& mod);
FeatureSet DetectFeature(const IRModule& mod);

/*!
* \brief Calculate the feature of the program.
Expand All @@ -159,7 +160,7 @@ FeatureSet DetectFeature(const Module& mod);
*
* \return The FeatureSet.
*/
inline FeatureSet DetectFeature(const Expr& expr, const Module& mod) {
inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) {
return DetectFeature(expr) + DetectFeature(mod);
}

Expand Down
Loading

0 comments on commit c69092a

Please sign in to comment.