From c908b7a466fba62348374e8a314214f18d431216 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 13:58:00 -0700 Subject: [PATCH] Clean up and docs --- include/tvm/relay/environment.h | 46 ++++- include/tvm/relay/expr_visitor.h | 14 +- include/tvm/relay/op.h | 2 +- python/tvm/relay/env.py | 2 + python/tvm/relay/ir_pass.py | 227 +---------------------- tests/scripts/task_python_integration.sh | 2 + 6 files changed, 52 insertions(+), 241 deletions(-) diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index fa805e2944d0..75ddc88674e6 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -51,28 +51,56 @@ class EnvironmentNode : public RelayNode { TVM_DLL static Environment make(tvm::Map global_funcs); + /*! \brief Add a function to the global environment. + * \param var The name of the global function. + * \param func The function. + * \param update Controls whether you can replace a definition in the + * environment. + */ void Add(const GlobalVar& var, const Function& func, bool update = false); + + /*! \brief Update a function in the global environment. + * \param var The name of the global function to update. + * \param func The new function. + */ void Update(const GlobalVar& var, const Function& func); + + /*! \brief Remove a function from the global environment. + * \param var The name of the global function to update. + */ void Remove(const GlobalVar& var); - /*! \brief Lookup a global function by its variable. */ + /*! \brief Lookup a global function by its variable. + * \param str The unique string specifying the global variable. + * \returns The global variable. + */ GlobalVar GetGlobalVar(const std::string& str); - /*! \brief Lookup a global function by its variable. */ - Function Lookup(const GlobalVar& id); + /*! \brief Lookup a global function by its variable. + * \param var The global var to lookup. + * \returns The function named by the variable argument. + */ + Function Lookup(const GlobalVar& var); - /*! \brief Lookup a global function by its string name */ - Function Lookup(const std::string& s); + /*! \brief Lookup a global function by its string name + * \param name The name of the function. + * \returns The function named by the argument. + */ + Function Lookup(const std::string& name); - // TODO(@jroesch, @tqchen): what are the semantics here - void Merge(const Environment& env); + /*! \brief Combine with another Environment. + * \param other The other environment. + */ + void Merge(const Environment& other); using Transformer = runtime::TypedPackedFunc(const Environment&)>; - /*! \brief Apply a function over every function in the global environment. */ - void Transform(Transformer tranformer); + /*! \brief Apply a function over every function in the global environment. + * \param transformer The transformation function. + */ + void Transform(Transformer transformer); static constexpr const char* _type_key = "relay.Environment"; TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 748b8ac02f97..4a26dcbd32e7 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -97,9 +97,9 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto type = this->VisitType(op->type); return ParamNode::make(var, type); } else { - CHECK(false) << "the default param visitor expected a Var found: " + LOG(FATAL) << "the default param visitor expected a Var found: " << var_expr << std::endl; - __builtin_unreachable(); + return Expr(); } } @@ -112,10 +112,10 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto ty_param_ref = GetRef(ty_param); ty_params.push_back(ty_param_ref); } else { - CHECK(false) + LOG(FATAL) << "the default function visitor expected a TypeParam found: " << ty_param_type << std::endl; - __builtin_unreachable(); + return Expr(); } } @@ -128,7 +128,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { } else { CHECK(false) << "the default function visitor expected a Param found: " << param_expr << std::endl; - __builtin_unreachable(); + return Expr(); } } @@ -165,9 +165,9 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto body = this->VisitExpr(op->body); return LetNode::make(var, value, body, type); } else { - CHECK(false) << "the default let visitor expected a Var found: " + LOG(FATAL) << "the default let visitor expected a Var found: " << var_expr << std::endl; - __builtin_unreachable(); + return Expr(); } } diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index a3037d3bebf4..3ab8c778c76d 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -208,7 +208,7 @@ class OpRegistry { } return *this; } - /*! \return The global single retistry */ + /*! \return The global single registry */ TVM_DLL static ::dmlc::Registry* Registry(); private: diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 4c73db0b524b..6f4362a77c2d 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -4,11 +4,13 @@ from . import _make from . import _env + @register_relay_node class Environment(NodeBase): """The global Relay environment containing functions, options and more. """ + def __init__(self, funcs) -> None: """Construct an environment. diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index ca396404610a..37f7001c460b 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -1,232 +1,11 @@ # pylint: disable=no-else-return, # pylint: disable=unidiomatic-typecheck -"""The optimizer for Relay. +"""The set of passes for Relay. -Exposes an interface for configuring the optimizer and scripting -it directly in Python. +Exposes an interface for configuring the passes and scripting +them in Python. """ -from typing import TypeVar, Generic, Union -from typing import Dict, Tuple, List, Callable -import tvm - -from .expr import Expr -from .expr import Function, Let, Call, Var -from .expr import GlobalVar, If, Constant -from .type import Type, TypeParam -from .env import Environment -from .op import Op -from .op.op import specialize_op -# import relay.make as relay_mk -# from relay import ir -# from relay.env import Environment -# from relay.tyck import check_expr -# from relay.first_order_reverse_ad import fo_with_gradient -# from relay.anf import to_anf from . import _ir_pass # Expose checking expression, should rename to infer_type. -# pylint: disable=invalid-name check_expr = _ir_pass.check_expr - -# # pylint: disable=invalid-name -# concretize = _opt.concretize - -# # pylint: disable=invalid-name -# optimize = _opt.optimize - -# # pylint: disable=invalid-name -# type_specialize = _opt.type_specialize - -# # pylint: disable=invalid-name -# compile_ops_to_module = _opt.compile_ops_to_module - - -@tvm.register_func("relay.mangle") -def mangle(name: str, types: List[Type]) -> str: - for typ in types: - name += str(typ) + "_" - return name - - -T = TypeVar('T') - - -class AbstractExprVisitor(Generic[T]): - """A functional visitor over Expr in Python.""" - - # pylint: disable=no-else-return - def visit(self, expr: Expr) -> T: - """Apply the visitor to an expression.""" - if isinstance(expr, Function): - return self.visit_function(expr) - elif isinstance(expr, Call): - return self.visit_call(expr) - elif isinstance(expr, Let): - return self.visit_let(expr) - elif isinstance(expr, Var): - return self.visit_local_var(expr) - elif isinstance(expr, GlobalVar): - return self.visit_global_var(expr) - elif isinstance(expr, If): - return self.visit_if(expr) - elif isinstance(expr, Tuple): - return self.visit_tuple(expr) - elif isinstance(expr, Constant): - return self.visit_constant(expr) - else: - raise Exception(f"warning unhandled case: {type(expr)}") - - def visit_function(self, _: Function) -> T: - raise Exception("Abstract method please implement me.") - - def visit_let(self, _: Let) -> T: - raise Exception("Abstract method please implement me.") - - def visit_call(self, _: Call) -> T: - raise Exception("Abstract method please implement me.") - - def visit_local_id(self, _: Var) -> T: - raise Exception("Abstract method please implement me.") - - def visit_type(self, typ: Type) -> Type: - return typ - - def visit_if(self, _: If) -> T: - raise Exception("Abstract method please implement me.") - - def visit_tuple(self, _: Tuple) -> T: - raise Exception("Abstract method please implement me.") - - def visit_constant(self, _: Constant) -> T: - raise Exception("Abstract method please implement me.") - - def visit_global_var(self, _: GlobalVar) -> T: - raise Exception("Abstract method please implement me.") - - @classmethod - def to_pass(cls) -> Callable[[Environment], Callable[[GlobalVar, Function], Function]]: - def _outer_wrapper(env): - visitor = cls(env) - - def _inner_wrapper(_, func): - return visitor.visit(func) - return _inner_wrapper - return _outer_wrapper - - -class ExprVisitor(AbstractExprVisitor[Expr]): - """A functional visitor over Expr in Python.""" - - def visit_function(self, fn: Function) -> Expr: - new_body = self.visit(fn.body) - return Function( - list(fn.params), - fn.ret_type, new_body, - fn.type_params) - - def visit_let(self, let: Let) -> Expr: - new_var = self.visit(let.var) - new_value_type = self.visit_type(let.value_type) - new_val = self.visit(let.value) - new_body = self.visit(let.body) - return Let(new_var, new_val, new_body, new_value_type) - - def visit_call(self, call: Call) -> Expr: - new_fn = self.visit(call.op) - new_args = [self.visit(arg) for arg in call.args] - return Call(new_fn, new_args, call.attrs) - - def visit_local_var(self, local_var: Var) -> Expr: - return local_var - - def visit_global_id(self, global_var: GlobalVar) -> Expr: - return global_var - - def visit_if(self, ite: If) -> Expr: - return If( - self.visit(ite.guard), - self.visit(ite.true_b), - self.visit(ite.false_b)) - - def visit_tuple(self, tup: Tuple) -> Expr: - return Tuple([self.visit(field) for field in tup.fields]) - - def visit_constant(self, const: Constant) -> Expr: - return const - - -MMCacheKey = Tuple[Union[GlobalVar, str], List[Type]] - - -class Monomorphize(ExprVisitor): - """A monomorphization pass. - - Implements what is known as "monomorphization" in - classic compiler literature. This pass removes - polymorphism replacing calls to functions and - operators with type specialized versions. - """ - monomorph_map: Dict[MMCacheKey, Union[Op, Function]] - - # pylint: disable=super-init-not-called - def __init__(self, env: Environment) -> None: - self.env = env - # Stores (GlobalVar, Type), should eventually store attributes. - self.monomorph_map = {} - - # pylint: disable=no-else-return - def visit_call(self, call: Call) -> Expr: - cache_key = (call.op, call.type_args) - new_args = [self.visit(arg) for arg in call.args] - - if cache_key in self.monomorph_map: - op = self.monomorph_map[cache_key] - new_args = [self.visit(arg) for arg in call.args] - return Call(op, new_args, call.attrs) - else: - if isinstance(call.op, Op): - poly_name = call.op.name - mono_name = mangle(poly_name, call.type_args) - for arg in call.type_args: - if isinstance(arg, TypeParam): - # raise Exception("...") # Fix me in the morning!!! - return call - - mono_op = specialize_op(poly_name, mono_name, call.type_args) - self.monomorph_map[cache_key] = mono_op - return Call(mono_op, new_args, call.attrs, []) - elif isinstance(call.op, GlobalVar): - return call - # defn = self.env.lookup(call.op) - # new_id = self.env.global_id(defn.id.name + str(1)) - # cache_key = (call.op, call.type_args) - # self.monomorph_map[cache_key] = new_id - # new_body = self.visit(type_specialize(call.type_args, defn.body)) - # new_body = Function( - # [], new_body.params, new_body.ret_type, new_body.body) - # new_ty = check_expr(self.env, new_body) - # # TODO(@jroesch): move into C++ - # # TODO(@joresch): implement and call name mangler - # defn = Defn(new_id, new_ty, new_body) - # self.env.add(defn) - # self.visit_item(defn) - # return Call(new_id, call.args, call.attrs) - - elif isinstance(call.op, Function): - return call - # new_func = type_specialize(call.type_args, call.op) - # new_func = self.visit(new_func) - # new_func = Function([], - # new_func.params, - # new_func.ret_type, - # new_func.body) - # check_expr(self.env, new_func) - # return Call(new_func, call.args, call.attrs) - else: - new_fn = self.visit(call.op) - return Call(new_fn, new_args, call.attrs) - - -# TODO(@jroesch): Fix up my type -__tgt_host__ = __tgt__ = "llvm" -__relay_tvm_context__ = tvm.cpu() diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 8104bf079502..7dcd5c921905 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -18,6 +18,8 @@ TVM_FFI=cython python -m nose -v tests/python/integration || exit -1 TVM_FFI=ctypes python3 -m nose -v tests/python/integration || exit -1 TVM_FFI=cython python -m nose -v tests/python/contrib || exit -1 TVM_FFI=ctypes python3 -m nose -v tests/python/contrib || exit -1 +TVM_FFI=cython python -m nose -v tests/python/relay || exit -1 +TVM_FFI=ctypes python3 -m nose -v tests/python/relay || exit -1 # Do not enabke OpenGL # TVM_FFI=cython python -m nose -v tests/webgl || exit -1