diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index c5cd6bb9e4ab..b1b8d6a7154e 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -575,6 +575,7 @@ std::string PrettyPrint(const NodeRef& node); std::string AsText(const NodeRef& node, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 3496c8815467..ee9b4873d28a 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -33,6 +33,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -185,6 +186,23 @@ class ModuleNode : public RelayNode { */ TVM_DLL void Update(const Module& other); + /*! + * \brief Import Relay code from the file at path. + * \param path The path of the Relay code to import. + * + * \note The path resolution behavior is standard, + * if abosolute will be the absolute file, if + * relative it will be resovled against the current + * working directory. + */ + TVM_DLL void Import(const std::string& path); + + /*! + * \brief Import Relay code from the file at path, relative to the standard library. + * \param path The path of the Relay code to import. + */ + TVM_DLL void ImportFromStd(const std::string& path); + /*! \brief Construct a module from a standalone expression. * * Allows one to optionally pass a global function map and @@ -222,6 +240,11 @@ class ModuleNode : public RelayNode { * for convenient access */ std::unordered_map constructor_tag_map_; + + /*! \brief The files previously imported, required to ensure + importing is idempotent for each module. + */ + std::unordered_set import_set_; }; struct Module : public NodeRef { @@ -235,6 +258,12 @@ struct Module : public NodeRef { using ContainerType = ModuleNode; }; +/*! \brief Parse Relay source into a module. + * \param source A string of Relay source code. + * \param source_name The name of the source file. + * \return A Relay module. + */ +Module FromText(const std::string& source, const std::string& source_name); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index d509fde2a875..16e36785c533 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -410,6 +410,12 @@ class TypeReporterNode : public Node { */ TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0; + /*! + * \brief Retrieve the current global module. + * \return The global module. + */ + TVM_DLL virtual Module GetModule() = 0; + // solver is not serializable. void VisitAttrs(tvm::AttrVisitor* v) final {} diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 092cd01d1d4a..ceb98c4d251e 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -17,6 +17,7 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" from __future__ import absolute_import +import os from sys import setrecursionlimit from ..api import register_func from . import base diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index e0511a257e6d..57980dd09cf2 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -16,13 +16,22 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import """A global module storing everything needed to interpret or compile a Relay program.""" +import os from .base import register_relay_node, RelayNode +from .. import register_func from .._ffi import base as _base from . import _make from . import _module from . import expr as _expr from . import ty as _ty +__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") + +@register_func("tvm.relay.std_path") +def _std_path(): + global __STD_PATH__ + return __STD_PATH__ + @register_relay_node class Module(RelayNode): """The global Relay module containing collection of functions. @@ -202,3 +211,9 @@ def from_expr(expr, functions=None, type_defs=None): funcs = functions if functions is not None else {} defs = type_defs if type_defs is not None else {} return _module.Module_FromExpr(expr, funcs, defs) + + def _import(self, file_to_import): + return _module.Module_Import(self, file_to_import) + + def import_from_std(self, file_to_import): + return _module.Module_ImportFromStd(self, file_to_import) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index f9a7d3dcaf37..d05b669ee7f1 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -16,14 +16,11 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" -import os from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple -from .parser import fromtext -__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__)) from .module import Module class Prelude: @@ -479,12 +476,10 @@ def load_prelude(self): Parses the portions of the Prelude written in Relay's text format and adds them to the module. """ - prelude_file = os.path.join(__PRELUDE_PATH__, "prelude.rly") - with open(prelude_file) as prelude: - prelude = fromtext(prelude.read()) - self.mod.update(prelude) - self.id = self.mod.get_global_var("id") - self.compose = self.mod.get_global_var("compose") + # TODO(@jroesch): we should remove this helper when we port over prelude + self.mod.import_from_std("prelude.rly") + self.id = self.mod.get_global_var("id") + self.compose = self.mod.get_global_var("compose") def __init__(self, mod=None): diff --git a/python/tvm/relay/prelude.rly b/python/tvm/relay/std/prelude.rly similarity index 100% rename from python/tvm/relay/prelude.rly rename to python/tvm/relay/std/prelude.rly diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index da9f7b8d19b9..6a2db6b46d64 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -444,7 +444,6 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } } - TVM_REGISTER_API("relay._expr.Bind") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef input = args[0]; diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index dbaea7f02fc7..2601f355d03e 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -26,6 +26,8 @@ #include #include #include +#include +#include namespace tvm { namespace relay { @@ -38,6 +40,9 @@ Module ModuleNode::make(tvm::Map global_funcs, auto n = make_node(); n->functions = std::move(global_funcs); n->type_definitions = std::move(global_type_defs); + n->global_type_var_map_ = {}; + n->global_var_map_ = {}; + n->constructor_tag_map_ = {}; for (const auto& kv : n->functions) { // set global var map @@ -85,6 +90,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, } GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { + CHECK(global_type_var_map_.defined()); auto it = global_type_var_map_.find(name); CHECK(it != global_type_var_map_.end()) << "Cannot find global type var " << name << " in the Module"; @@ -162,6 +168,7 @@ void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { // set global type var map CHECK(!global_type_var_map_.count(var->var->name_hint)) << "Duplicate global type definition name " << var->var->name_hint; + global_type_var_map_.Set(var->var->name_hint, var); RegisterConstructors(var, type); @@ -241,6 +248,40 @@ Module ModuleNode::FromExpr( return mod; } +void ModuleNode::Import(const std::string& path) { + LOG(INFO) << "Importing: " << path; + if (this->import_set_.count(path) == 0) { + this->import_set_.insert(path); + std::fstream src_file(path, std::fstream::in); + std::string file_contents { + std::istreambuf_iterator(src_file), + std::istreambuf_iterator() }; + auto mod_to_import = FromText(file_contents, path); + + for (auto func : mod_to_import->functions) { + this->Add(func.first, func.second, false); + } + + for (auto type : mod_to_import->type_definitions) { + this->AddDef(type.first, type.second); + } + } +} + +void ModuleNode::ImportFromStd(const std::string& path) { + auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path"); + CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; + std::string std_path = (*f)(); + return this->Import(std_path + "/" + path); +} + +Module FromText(const std::string& source, const std::string& source_name) { + auto* f = tvm::runtime::Registry::Get("relay.fromtext"); + CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; + Module mod = (*f)(source, source_name); + return mod; +} + TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_API("relay._make.Module") @@ -320,6 +361,16 @@ TVM_REGISTER_API("relay._module.Module_Update") mod->Update(from); }); +TVM_REGISTER_API("relay._module.Module_Import") +.set_body_typed([](Module mod, std::string path) { + mod->Import(path); +}); + +TVM_REGISTER_API("relay._module.Module_ImportFromStd") +.set_body_typed([](Module mod, std::string path) { + mod->ImportFromStd(path); +});; + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch( [](const ModuleNode *node, tvm::IRPrinter *p) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index e8bdc090c947..5b9b25bd61f9 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -108,7 +108,8 @@ class TypeInferencer : private ExprFunctor, explicit TypeInferencer(Module mod, GlobalVar current_func) : mod_(mod), current_func_(current_func), - err_reporter(), solver_(current_func, &this->err_reporter) { + err_reporter(), solver_(current_func, mod, &this->err_reporter) { + CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer"; } // inference the type of expr. @@ -790,36 +791,22 @@ void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); } -Expr InferType(const Expr& expr, const Module& mod_ref) { - if (!mod_ref.defined()) { - Module mod = ModuleNode::FromExpr(expr); - // NB(@jroesch): By adding the expression to the module we will - // type check it anyway; afterwards we can just recover type - // from the type-checked function to avoid doing unnecessary work. - - Function func = mod->Lookup("main"); - - // FromExpr wraps a naked expression as a function, we will unbox - // it here. - if (expr.as()) { - return std::move(func); - } else { - return func->body; - } - } else { - auto e = TypeInferencer(mod_ref, mod_ref->GetGlobalVar("main")).Infer(expr); - CHECK(WellFormed(e)); - auto free_tvars = FreeTypeVars(e, mod_ref); - CHECK(free_tvars.size() == 0) - << "Found unbound type variables in " << e << ": " << free_tvars; - EnsureCheckedType(e); - return e; - } +Expr InferType(const Expr& expr, const Module& mod) { + auto main = mod->GetGlobalVar("main"); + auto inferencer = TypeInferencer(mod, main); + auto e = inferencer.Infer(expr); + CHECK(WellFormed(e)); + auto free_tvars = FreeTypeVars(e, mod); + CHECK(free_tvars.size() == 0) + << "Found unbound type variables in " << e << ": " << free_tvars; + EnsureCheckedType(e); + return e; } Function InferType(const Function& func, const Module& mod, const GlobalVar& var) { + CHECK(mod.defined()) << "internal error: module must be set for type inference"; Function func_copy = Function(make_node(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); mod->AddUnchecked(var, func_copy); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 743a4c7774b8..31edd3b0e80e 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -61,6 +61,10 @@ class TypeSolver::Reporter : public TypeReporterNode { location = ref; } + TVM_DLL Module GetModule() final { + return this->solver_->module_; + } + private: /*! \brief The location to report unification errors at. */ mutable NodeRef location; @@ -526,10 +530,13 @@ class TypeSolver::Merger : public TypeFunctor { }; // constructor -TypeSolver::TypeSolver(const GlobalVar ¤t_func, ErrorReporter* err_reporter) - : reporter_(make_node(this)), - current_func(current_func), - err_reporter_(err_reporter) { +TypeSolver::TypeSolver(const GlobalVar& current_func, const Module& module, + ErrorReporter* err_reporter) + : reporter_(make_node(this)), + current_func(current_func), + err_reporter_(err_reporter), + module_(module) { + CHECK(module_.defined()) << "internal error: module must be defined"; } // destructor @@ -653,18 +660,22 @@ TVM_REGISTER_API("relay._analysis._test_type_solver") using runtime::PackedFunc; using runtime::TypedPackedFunc; ErrorReporter *err_reporter = new ErrorReporter(); - auto solver = std::make_shared(GlobalVarNode::make("test"), err_reporter); + auto module = ModuleNode::make({}, {}); + auto dummy_fn_name = GlobalVarNode::make("test"); + module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {})); + auto solver = std::make_shared(dummy_fn_name, module, err_reporter); - auto mod = [solver, err_reporter](std::string name) -> PackedFunc { + auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc { if (name == "Solve") { return TypedPackedFunc([solver]() { return solver->Solve(); }); } else if (name == "Unify") { - return TypedPackedFunc([solver, err_reporter](Type lhs, Type rhs) { + return TypedPackedFunc( + [module, solver, err_reporter](Type lhs, Type rhs) { auto res = solver->Unify(lhs, rhs, lhs); if (err_reporter->AnyErrors()) { - err_reporter->RenderErrors(ModuleNode::make({}, {}), true); + err_reporter->RenderErrors(module, true); } return res; }); diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 28579633c1c6..4a6d2cfa7756 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -63,7 +63,7 @@ using common::LinkedList; */ class TypeSolver { public: - TypeSolver(const GlobalVar& current_func, ErrorReporter* err_reporter); + TypeSolver(const GlobalVar& current_func, const Module& _mod, ErrorReporter* err_reporter); ~TypeSolver(); /*! * \brief Add a type constraint to the solver. @@ -179,6 +179,8 @@ class TypeSolver { GlobalVar current_func; /*! \brief Error reporting. */ ErrorReporter* err_reporter_; + /*! \brief The module. */ + Module module_; /*! * \brief GetTypeNode that is corresponds to t.