diff --git a/python/tvm/_ffi/_ctypes/vmobj.py b/python/tvm/_ffi/_ctypes/vmobj.py index 498167c7819e3..59930e55c3829 100644 --- a/python/tvm/_ffi/_ctypes/vmobj.py +++ b/python/tvm/_ffi/_ctypes/vmobj.py @@ -39,7 +39,6 @@ def _return_object(x): tag = ctypes.c_int() check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag))) cls = OBJECT_TYPE.get(tag.value, ObjectBase) - print('here') obj = cls(handle) return obj diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 5e1a3c6b338d6..572d9bf399ae6 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -42,8 +42,8 @@ def _update_target(target): dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type) tgts[dev_type] = tvm.target.create(tgt) else: - raise TypeError("target is expected to be str or " + - "tvm.target.Target, but received " + + raise TypeError("target is expected to be str, tvm.target.Target, " + + "or dict of str to str/tvm.target.Target, but received " + "{}".format(type(target))) return tgts @@ -134,7 +134,7 @@ def compile(self, mod, target=None, target_host=None): The Relay module to build. target : str, :any:`tvm.target.Target`, or dict of str(i.e. - device/context name) to str/tvm.target.Target, optional + device/context name) to str/tvm.target.Target, optional For heterogeneous compilation, it is a dictionary indicating context to target mapping. For homogeneous compilation, it is a build target. @@ -178,7 +178,8 @@ class VMExecutor(Executor): The target option to build the function. """ def __init__(self, mod, ctx, target): - assert mod is not None + if mod is None: + raise RuntimeError("Must provide module to get VM executor.") self.mod = mod self.ctx = ctx self.target = target diff --git a/python/tvm/relay/backend/vmobj.py b/python/tvm/relay/backend/vmobj.py index c4d96b2890989..1bd12c8db3007 100644 --- a/python/tvm/relay/backend/vmobj.py +++ b/python/tvm/relay/backend/vmobj.py @@ -22,13 +22,15 @@ from tvm import ndarray as _nd from . import _vmobj +# TODO(@icemelon9): Add ClosureObject + @register_object class TensorObject(Object): """Tensor object.""" tag = ObjectTag.TENSOR def __init__(self, handle): - """Constructs a tensor object + """Constructs a Tensor object Parameters ---------- @@ -60,7 +62,7 @@ class DatatypeObject(Object): tag = ObjectTag.DATATYPE def __init__(self, handle): - """Constructs a tensor object + """Constructs a Datatype object Parameters ---------- @@ -69,8 +71,8 @@ def __init__(self, handle): Returns ------- - obj : TensorObject - A tensor object. + obj : DatatypeObject + A Datatype object. """ super(DatatypeObject, self).__init__(handle) self.tag = _vmobj.GetDatatypeTag(self) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b9ff8728113f0..05d388e3141e8 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -230,7 +230,6 @@ class VMCompiler : ExprFunctor { void VisitExpr_(const LetNode* let_node) { DLOG(INFO) << let_node->value; this->VisitExpr(let_node->value); - DLOG(INFO) << this->last_register_; var_register_map_.insert({let_node->var, this->last_register_}); this->VisitExpr(let_node->body); } @@ -443,23 +442,21 @@ class VMCompiler : ExprFunctor { protected: /*! \brief Store the expression a variable points to. */ std::unordered_map expr_map_; - + /*! \brief Instructions in the VMFunction. */ std::vector instructions_; - + /*! \brief Parameter names of the function. */ std::vector params_; - - // var -> register num + /*! \brief Map from var to register number. */ std::unordered_map var_register_map_; - + /*! \brief Last used register number. */ size_t last_register_; - - // Total number of virtual registers allocated + /*! \brief Total number of virtual registers allocated. */ size_t registers_num_; + /*! \brief Compiler engine to lower primitive functions. */ CompileEngine engine_; - /*! \brief Global shared meta data */ VMCompilerContext* context_; - + /*! \brief Target devices. */ TargetsMap targets_; }; @@ -599,9 +596,13 @@ class VMBuildModule : public runtime::ModuleNode { } protected: + /*! \brief Target devices. */ TargetsMap targets_; + /*! \brief Target host device. */ tvm::Target target_host_; + /*! \brief Global shared meta data */ VMCompilerContext context_; + /*! \brief Compiled virtual machine. */ std::shared_ptr vm_; }; diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 24ef664e651a9..255dc23d38d4a 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -22,7 +22,6 @@ from tvm import relay from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.prelude import Prelude -#from tvm.relay.vm import BuildModule def veval(f, *args, ctx=tvm.cpu(), target="llvm"): if isinstance(f, relay.Expr):