Skip to content

Commit

Permalink
use _ffi_api
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Feb 11, 2020
1 parent a31b23e commit 2b92133
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 86 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def compile(mod, target=None, target_host=None, params=None):
Returns
-------
exec : tvm.runtime.vmf.Executable
exec : tvm.runtime.vm.Executable
The VM executable that contains both library code and bytecode.
"""
compiler = VMCompiler()
Expand Down
21 changes: 0 additions & 21 deletions python/tvm/runtime/_vm.py

This file was deleted.

7 changes: 4 additions & 3 deletions python/tvm/runtime/profiler_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@
Provides extra APIs for profiling vm execution.
"""
from . import vm, _vm
from tvm.runtime import _ffi_api
from . import vm

def enabled():
"""Whether vm profiler is enabled."""
return hasattr(_vm, "_VirtualMachineDebug")
return hasattr(_ffi_api, "_VirtualMachineDebug")

class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime."""
def __init__(self, mod):
super(VirtualMachineProfiler, self).__init__(mod)
m = mod.module if isinstance(mod, vm.Executable) else mod
self.mod = _vm._VirtualMachineDebug(m)
self.mod = _ffi_api._VirtualMachineDebug(m)
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
self._get_stat = self.mod["get_stat"]
Expand Down
64 changes: 10 additions & 54 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,20 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin
"""
The Relay Virtual Machine.
The Relay Virtual Machine runtime.
Implements a Python interface to compiling and executing on the Relay VM.
Implements a Python interface to executing the compiled VM object.
"""
import numpy as np

import tvm
from tvm.runtime import Object, container
from tvm.relay import expr as _expr
from tvm.runtime import _ffi_api
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from tvm.relay.backend.interpreter import Executor
from . import _vm

def _convert(arg, cargs):
if isinstance(arg, _expr.Constant):
cargs.append(arg.data)
elif isinstance(arg, Object):
if isinstance(arg, Object):
cargs.append(arg)
elif isinstance(arg, np.ndarray):
nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0))
Expand Down Expand Up @@ -163,7 +159,7 @@ def load_exec(bytecode, lib):
raise TypeError("lib is expected to be the type of tvm.runtime.Module" +
", but received {}".format(type(lib)))

return Executable(_vm.Load_Executable(bytecode, lib))
return Executable(_ffi_api.Load_Executable(bytecode, lib))

@property
def lib(self):
Expand Down Expand Up @@ -197,9 +193,9 @@ def primitive_ops(self):
The list of primitive ops.
"""
ret = []
num_primitives = _vm.GetNumOfPrimitives(self.module)
num_primitives = _ffi_api.GetNumOfPrimitives(self.module)
for i in range(num_primitives):
ret.append(_vm.GetPrimitiveFields(self.module, i))
ret.append(_ffi_api.GetPrimitiveFields(self.module, i))
return ret

@property
Expand Down Expand Up @@ -240,9 +236,9 @@ def globals(self):
The globals contained in the executable.
"""
ret = []
num_globals = _vm.GetNumOfGlobals(self.module)
num_globals = _ffi_api.GetNumOfGlobals(self.module)
for i in range(num_globals):
ret.append(_vm.GetGlobalFields(self.module, i))
ret.append(_ffi_api.GetGlobalFields(self.module, i))
return ret

@property
Expand Down Expand Up @@ -272,7 +268,7 @@ def __init__(self, mod):
raise TypeError("mod is expected to be the type of Executable or " +
"tvm.Module, but received {}".format(type(mod)))
m = mod.module if isinstance(mod, Executable) else mod
self.mod = _vm._VirtualMachine(m)
self.mod = _ffi_api._VirtualMachine(m)
self._exec = mod
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
Expand Down Expand Up @@ -359,43 +355,3 @@ def run(self, *args, **kwargs):
The output.
"""
return self.invoke("main", *args, **kwargs)


class VMExecutor(Executor):
"""
An implementation of the executor interface for
the Relay VM.
Useful interface for experimentation and debugging
the VM can also be used directly from the API.
supported by `tvm.runtime.vm`.
Parameters
----------
mod : :py:class:`~tvm.relay.module.Module`
The module to support the execution.
ctx : :py:class:`~tvm.TVMContext`
The runtime context to run the code on.
target : :py:class:`Target`
The target option to build the function.
"""
def __init__(self, mod, ctx, target):
if mod is None:
raise RuntimeError("Must provide module to get VM executor.")
self.mod = mod
self.ctx = ctx
self.target = target
self.executable = compile(mod, target)
self.vm = VirtualMachine(self.executable)
self.vm.init(ctx)

def _make_executor(self, expr=None):
main = self.mod["main"]

def _vm_wrapper(*args, **kwargs):
args = self._convert_args(main, args, kwargs)
return self.vm.run(*args)

return _vm_wrapper
10 changes: 5 additions & 5 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -738,15 +738,15 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) {
}
}

TVM_REGISTER_GLOBAL("runtime._vm.GetNumOfGlobals")
TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec);
*rv = static_cast<int>(exec->global_map.size());
});

TVM_REGISTER_GLOBAL("runtime._vm.GetGlobalFields")
TVM_REGISTER_GLOBAL("runtime.GetGlobalFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
Expand All @@ -763,7 +763,7 @@ TVM_REGISTER_GLOBAL("runtime._vm.GetGlobalFields")
*rv = globals[idx].first;
});

TVM_REGISTER_GLOBAL("runtime._vm.GetNumOfPrimitives")
TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
Expand All @@ -772,7 +772,7 @@ TVM_REGISTER_GLOBAL("runtime._vm.GetNumOfPrimitives")
});


TVM_REGISTER_GLOBAL("runtime._vm.GetPrimitiveFields")
TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
Expand All @@ -789,7 +789,7 @@ TVM_REGISTER_GLOBAL("runtime._vm.GetPrimitiveFields")
}
});

TVM_REGISTER_GLOBAL("runtime._vm.Load_Executable")
TVM_REGISTER_GLOBAL("runtime.Load_Executable")
.set_body_typed([](
std::string code,
runtime::Module lib) {
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/vm/profiler/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
return runtime::Module(vm);
}

TVM_REGISTER_GLOBAL("relay._vm._VirtualMachineDebug")
TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,7 @@ runtime::Module CreateVirtualMachine(const Executable* exec) {
return runtime::Module(vm);
}

TVM_REGISTER_GLOBAL("runtime._vm._VirtualMachine")
TVM_REGISTER_GLOBAL("runtime._VirtualMachine")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_runtime_vm_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
from tvm.runtime import profiler_vm
from tvm import relay
Expand Down

0 comments on commit 2b92133

Please sign in to comment.