Skip to content

Commit

Permalink
[Relay][Module] Refactor the way we interface between different modul…
Browse files Browse the repository at this point in the history
…es of Relay. (apache#3906)

* Module refactor

* Add load module

* Add support for idempotent import

* Tweak load paths

* Move path around

* Expose C++ import functions in Python

* Fix import

* Add doc string

* Fix

* Fix lint

* Fix lint

* Fix test failure

* Add type solver

* Fix lint
  • Loading branch information
jroesch authored and wweic committed Sep 16, 2019
1 parent 67e4deb commit d8b5046
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 45 deletions.
1 change: 1 addition & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ std::string PrettyPrint(const NodeRef& node);
std::string AsText(const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
29 changes: 29 additions & 0 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -185,6 +186,23 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL void Update(const Module& other);

/*!
* \brief Import Relay code from the file at path.
* \param path The path of the Relay code to import.
*
* \note The path resolution behavior is standard,
* if abosolute will be the absolute file, if
* relative it will be resovled against the current
* working directory.
*/
TVM_DLL void Import(const std::string& path);

/*!
* \brief Import Relay code from the file at path, relative to the standard library.
* \param path The path of the Relay code to import.
*/
TVM_DLL void ImportFromStd(const std::string& path);

/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map and
Expand Down Expand Up @@ -222,6 +240,11 @@ class ModuleNode : public RelayNode {
* for convenient access
*/
std::unordered_map<int32_t, Constructor> constructor_tag_map_;

/*! \brief The files previously imported, required to ensure
importing is idempotent for each module.
*/
std::unordered_set<std::string> import_set_;
};

struct Module : public NodeRef {
Expand All @@ -235,6 +258,12 @@ struct Module : public NodeRef {
using ContainerType = ModuleNode;
};

/*! \brief Parse Relay source into a module.
* \param source A string of Relay source code.
* \param source_name The name of the source file.
* \return A Relay module.
*/
Module FromText(const std::string& source, const std::string& source_name);

} // namespace relay
} // namespace tvm
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,12 @@ class TypeReporterNode : public Node {
*/
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;

/*!
* \brief Retrieve the current global module.
* \return The global module.
*/
TVM_DLL virtual Module GetModule() = 0;

// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import
import os
from sys import setrecursionlimit
from ..api import register_func
from . import base
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/relay/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,22 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global module storing everything needed to interpret or compile a Relay program."""
import os
from .base import register_relay_node, RelayNode
from .. import register_func
from .._ffi import base as _base
from . import _make
from . import _module
from . import expr as _expr
from . import ty as _ty

__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")

@register_func("tvm.relay.std_path")
def _std_path():
global __STD_PATH__
return __STD_PATH__

@register_relay_node
class Module(RelayNode):
"""The global Relay module containing collection of functions.
Expand Down Expand Up @@ -202,3 +211,9 @@ def from_expr(expr, functions=None, type_defs=None):
funcs = functions if functions is not None else {}
defs = type_defs if type_defs is not None else {}
return _module.Module_FromExpr(expr, funcs, defs)

def _import(self, file_to_import):
return _module.Module_Import(self, file_to_import)

def import_from_std(self, file_to_import):
return _module.Module_ImportFromStd(self, file_to_import)
13 changes: 4 additions & 9 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
import os
from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple
from .parser import fromtext
__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
from .module import Module

class Prelude:
Expand Down Expand Up @@ -479,12 +476,10 @@ def load_prelude(self):
Parses the portions of the Prelude written in Relay's text format and adds
them to the module.
"""
prelude_file = os.path.join(__PRELUDE_PATH__, "prelude.rly")
with open(prelude_file) as prelude:
prelude = fromtext(prelude.read())
self.mod.update(prelude)
self.id = self.mod.get_global_var("id")
self.compose = self.mod.get_global_var("compose")
# TODO(@jroesch): we should remove this helper when we port over prelude
self.mod.import_from_std("prelude.rly")
self.id = self.mod.get_global_var("id")
self.compose = self.mod.get_global_var("compose")


def __init__(self, mod=None):
Expand Down
File renamed without changes.
1 change: 0 additions & 1 deletion src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,6 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
}
}


TVM_REGISTER_API("relay._expr.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef input = args[0];
Expand Down
51 changes: 51 additions & 0 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <sstream>
#include <fstream>
#include <unordered_set>

namespace tvm {
namespace relay {
Expand All @@ -38,6 +40,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
auto n = make_node<ModuleNode>();
n->functions = std::move(global_funcs);
n->type_definitions = std::move(global_type_defs);
n->global_type_var_map_ = {};
n->global_var_map_ = {};
n->constructor_tag_map_ = {};

for (const auto& kv : n->functions) {
// set global var map
Expand Down Expand Up @@ -85,6 +90,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
}

GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
CHECK(global_type_var_map_.defined());
auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end())
<< "Cannot find global type var " << name << " in the Module";
Expand Down Expand Up @@ -162,6 +168,7 @@ void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
// set global type var map
CHECK(!global_type_var_map_.count(var->var->name_hint))
<< "Duplicate global type definition name " << var->var->name_hint;

global_type_var_map_.Set(var->var->name_hint, var);
RegisterConstructors(var, type);

Expand Down Expand Up @@ -241,6 +248,40 @@ Module ModuleNode::FromExpr(
return mod;
}

void ModuleNode::Import(const std::string& path) {
LOG(INFO) << "Importing: " << path;
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
std::fstream src_file(path, std::fstream::in);
std::string file_contents {
std::istreambuf_iterator<char>(src_file),
std::istreambuf_iterator<char>() };
auto mod_to_import = FromText(file_contents, path);

for (auto func : mod_to_import->functions) {
this->Add(func.first, func.second, false);
}

for (auto type : mod_to_import->type_definitions) {
this->AddDef(type.first, type.second);
}
}
}

void ModuleNode::ImportFromStd(const std::string& path) {
auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
std::string std_path = (*f)();
return this->Import(std_path + "/" + path);
}

Module FromText(const std::string& source, const std::string& source_name) {
auto* f = tvm::runtime::Registry::Get("relay.fromtext");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
Module mod = (*f)(source, source_name);
return mod;
}

TVM_REGISTER_NODE_TYPE(ModuleNode);

TVM_REGISTER_API("relay._make.Module")
Expand Down Expand Up @@ -320,6 +361,16 @@ TVM_REGISTER_API("relay._module.Module_Update")
mod->Update(from);
});

TVM_REGISTER_API("relay._module.Module_Import")
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
mod->Import(path);
});

TVM_REGISTER_API("relay._module.Module_ImportFromStd")
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
mod->ImportFromStd(path);
});;

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ModuleNode>(
[](const ModuleNode *node, tvm::IRPrinter *p) {
Expand Down
39 changes: 13 additions & 26 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,

explicit TypeInferencer(Module mod, GlobalVar current_func)
: mod_(mod), current_func_(current_func),
err_reporter(), solver_(current_func, &this->err_reporter) {
err_reporter(), solver_(current_func, mod, &this->err_reporter) {
CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer";
}

// inference the type of expr.
Expand Down Expand Up @@ -790,36 +791,22 @@ void EnsureCheckedType(const Expr& e) {
AllCheckTypePopulated().VisitExpr(e);
}

Expr InferType(const Expr& expr, const Module& mod_ref) {
if (!mod_ref.defined()) {
Module mod = ModuleNode::FromExpr(expr);
// NB(@jroesch): By adding the expression to the module we will
// type check it anyway; afterwards we can just recover type
// from the type-checked function to avoid doing unnecessary work.

Function func = mod->Lookup("main");

// FromExpr wraps a naked expression as a function, we will unbox
// it here.
if (expr.as<FunctionNode>()) {
return std::move(func);
} else {
return func->body;
}
} else {
auto e = TypeInferencer(mod_ref, mod_ref->GetGlobalVar("main")).Infer(expr);
CHECK(WellFormed(e));
auto free_tvars = FreeTypeVars(e, mod_ref);
CHECK(free_tvars.size() == 0)
<< "Found unbound type variables in " << e << ": " << free_tvars;
EnsureCheckedType(e);
return e;
}
Expr InferType(const Expr& expr, const Module& mod) {
auto main = mod->GetGlobalVar("main");
auto inferencer = TypeInferencer(mod, main);
auto e = inferencer.Infer(expr);
CHECK(WellFormed(e));
auto free_tvars = FreeTypeVars(e, mod);
CHECK(free_tvars.size() == 0)
<< "Found unbound type variables in " << e << ": " << free_tvars;
EnsureCheckedType(e);
return e;
}

Function InferType(const Function& func,
const Module& mod,
const GlobalVar& var) {
CHECK(mod.defined()) << "internal error: module must be set for type inference";
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->func_type_annotation();
mod->AddUnchecked(var, func_copy);
Expand Down
27 changes: 19 additions & 8 deletions src/relay/pass/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class TypeSolver::Reporter : public TypeReporterNode {
location = ref;
}

TVM_DLL Module GetModule() final {
return this->solver_->module_;
}

private:
/*! \brief The location to report unification errors at. */
mutable NodeRef location;
Expand Down Expand Up @@ -526,10 +530,13 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
};

// constructor
TypeSolver::TypeSolver(const GlobalVar &current_func, ErrorReporter* err_reporter)
: reporter_(make_node<Reporter>(this)),
current_func(current_func),
err_reporter_(err_reporter) {
TypeSolver::TypeSolver(const GlobalVar& current_func, const Module& module,
ErrorReporter* err_reporter)
: reporter_(make_node<Reporter>(this)),
current_func(current_func),
err_reporter_(err_reporter),
module_(module) {
CHECK(module_.defined()) << "internal error: module must be defined";
}

// destructor
Expand Down Expand Up @@ -653,18 +660,22 @@ TVM_REGISTER_API("relay._analysis._test_type_solver")
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
ErrorReporter *err_reporter = new ErrorReporter();
auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), err_reporter);
auto module = ModuleNode::make({}, {});
auto dummy_fn_name = GlobalVarNode::make("test");
module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {}));
auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);

auto mod = [solver, err_reporter](std::string name) -> PackedFunc {
auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc {
if (name == "Solve") {
return TypedPackedFunc<bool()>([solver]() {
return solver->Solve();
});
} else if (name == "Unify") {
return TypedPackedFunc<Type(Type, Type)>([solver, err_reporter](Type lhs, Type rhs) {
return TypedPackedFunc<Type(Type, Type)>(
[module, solver, err_reporter](Type lhs, Type rhs) {
auto res = solver->Unify(lhs, rhs, lhs);
if (err_reporter->AnyErrors()) {
err_reporter->RenderErrors(ModuleNode::make({}, {}), true);
err_reporter->RenderErrors(module, true);
}
return res;
});
Expand Down
4 changes: 3 additions & 1 deletion src/relay/pass/type_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ using common::LinkedList;
*/
class TypeSolver {
public:
TypeSolver(const GlobalVar& current_func, ErrorReporter* err_reporter);
TypeSolver(const GlobalVar& current_func, const Module& _mod, ErrorReporter* err_reporter);
~TypeSolver();
/*!
* \brief Add a type constraint to the solver.
Expand Down Expand Up @@ -179,6 +179,8 @@ class TypeSolver {
GlobalVar current_func;
/*! \brief Error reporting. */
ErrorReporter* err_reporter_;
/*! \brief The module. */
Module module_;

/*!
* \brief GetTypeNode that is corresponds to t.
Expand Down

0 comments on commit d8b5046

Please sign in to comment.