Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] move vm.py under runtime and adt to runtime.container.py #4855

Merged
merged 5 commits into from
Feb 11, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _lower(mod,
"""
# pylint: disable=import-outside-toplevel
from tvm import relay
from tvm import runtime
from tvm.relay.backend import graph_runtime_codegen

if hasattr(target, 'device_name') and target.device_name == "vta":
Expand All @@ -49,7 +50,7 @@ def _lower(mod,
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
grc.codegen(mod["main"])
# default case
compiler = relay.vm.VMCompiler()
compiler = runtime.vm.VMCompiler()
zhiics marked this conversation as resolved.
Show resolved Hide resolved
if params:
compiler.set_params(params)
compiler.lower(mod, target=target)
Expand Down
55 changes: 1 addition & 54 deletions python/tvm/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Container data structures used in TVM DSL."""
import tvm._ffi

from tvm.runtime import Object, ObjectTypes
from tvm.runtime import Object
from tvm.runtime.container import getitem_helper
from tvm.runtime import _ffi_node_api
from . import _api_internal
Expand Down Expand Up @@ -104,56 +104,3 @@ class LoweredFunc(Object):
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2


@tvm._ffi.register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)

@property
def tag(self):
return _GetADTTag(self)

def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)

def __len__(self):
return _GetADTSize(self)


def tuple_object(fields=None):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)


tvm._ffi._init_api("tvm.container")
2 changes: 0 additions & 2 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
from . import debug
from . import param_dict
from . import feature
from .backend import vm
from .backend import profiler_vm

# Root operators
from .op import Op
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import numpy as np

from tvm import container
from tvm.runtime import container
from . import _backend
from .. import _make, analysis, transform
from .. import module
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
import numpy as np

from tvm import expr as tvm_expr
from tvm.runtime import vm
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import _build_module
from . import ty as _ty
from . import expr as _expr
from .module import Module as _Module
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor

def _update_target(target):
target = target if target else _target.current_target()
Expand Down Expand Up @@ -408,5 +408,5 @@ def create_executor(kind="debug",
if kind == "graph":
return GraphExecutor(mod, ctx, target)
if kind == "vm":
return VMExecutor(mod, ctx, target)
return vm.VMExecutor(mod, ctx, target)
raise RuntimeError("unknown execution strategy: {0}".format(kind))
4 changes: 2 additions & 2 deletions python/tvm/relay/testing/py_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
# import numpy
# import tvm
# from tvm import relay
# from tvm import import container as _container
# from tvm import nd
# from tvm.runtime import import container as _container
# from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [
ast.Import([alias('numpy', None)]),
ast.Import([alias('tvm', None)]),
ast.ImportFrom('tvm', [alias('relay', None)], 0),
ast.ImportFrom('tvm', [alias('nd', None)], 0),
ast.ImportFrom('tvm', [alias('container', '_container')],
ast.ImportFrom('tvm.runtime', [alias('container', '_container')],
0),
ast.ImportFrom('tvm.relay.backend.interpreter',
[alias('RefValue', None),
Expand Down
File renamed without changes.
56 changes: 56 additions & 0 deletions python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Runtime container structures."""
import tvm._ffi

from tvm.runtime import Object, ObjectTypes

def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Expand Down Expand Up @@ -54,3 +57,56 @@ def getitem_helper(obj, elem_getter, length, idx):
if idx < 0:
idx += length
return elem_getter(obj, idx)


@tvm._ffi.register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.

Parameters
----------
tag : int
The tag of ADT.

fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)

@property
def tag(self):
return _GetADTTag(self)

def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)

def __len__(self):
return _GetADTSize(self)


def tuple_object(fields=None):
"""Create a ADT object from source tuple.

Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.

Returns
-------
ret : ADT
The created object.
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)


tvm._ffi._init_api("tvm.runtime.container")
14 changes: 7 additions & 7 deletions python/tvm/relay/backend/vm.py → python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@

import tvm
import tvm.runtime.ndarray as _nd
from tvm.runtime import Object
from tvm import autotvm, container
from tvm.runtime import Object, container
from tvm import autotvm
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from tvm.relay.backend.interpreter import Executor
from . import _vm
from .interpreter import Executor

def _convert(arg, cargs):
if isinstance(arg, _expr.Constant):
Expand Down Expand Up @@ -117,7 +117,7 @@ def save(self):
# create a Relay VM.
ctx = tvm.cpu()
target = "llvm"
executable = relay.vm.compile(mod, target)
executable = tvm.runtime.vm.compile(mod, target)
code, lib = executable.save()
# save and load the code and lib file.
tmp = tvm.contrib.util.tempdir()
Expand All @@ -128,10 +128,10 @@ def save(self):
loaded_lib = tvm.runtime.load_module(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read())
# deserialize.
des_exec = relay.vm.Executable.load_exec(loaded_code, loaded_code)
des_exec = tvm.runtime.vm.Executable.load_exec(loaded_code, loaded_code)
# execute the deserialized executable.
x_data = np.random.rand(10, 10).astype('float32')
des_vm = relay.vm.VirtualMachine(des_exec)
des_vm = tvm.runtime.vm.VirtualMachine(des_exec)
des_vm.init(ctx)
res = des_vm.run(x_data)
print(res.asnumpy())
Expand Down Expand Up @@ -556,7 +556,7 @@ class VMExecutor(Executor):

Useful interface for experimentation and debugging
the VM can also be used directly from the API.
supported by `tvm.relay.vm`.
supported by `tvm.runtime.vm`.

Parameters
----------
Expand Down
10 changes: 5 additions & 5 deletions src/runtime/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,22 @@ namespace runtime {

using namespace vm;

TVM_REGISTER_GLOBAL("container._GetADTTag")
TVM_REGISTER_GLOBAL("runtime.container._GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.tag());
});

TVM_REGISTER_GLOBAL("container._GetADTSize")
TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.size());
});


TVM_REGISTER_GLOBAL("container._GetADTFields")
TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
int idx = args[1];
Expand All @@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("container._GetADTFields")
*rv = adt[idx];
});

TVM_REGISTER_GLOBAL("container._Tuple")
TVM_REGISTER_GLOBAL("runtime.container._Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) {
Expand All @@ -65,7 +65,7 @@ TVM_REGISTER_GLOBAL("container._Tuple")
*rv = ADT::Tuple(fields);
});

TVM_REGISTER_GLOBAL("container._ADT")
TVM_REGISTER_GLOBAL("runtime.container._ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
Expand Down
2 changes: 1 addition & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def convert_to_list(x):
def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.container.ADT):
elif isinstance(o, tvm.runtime.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f))
Expand Down
3 changes: 2 additions & 1 deletion tests/python/relay/benchmarking/benchmark_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

import tvm
from tvm.contrib import graph_runtime
from tvm import relay, container
from tvm import relay
from tvm.runtime import container
from tvm.relay import testing
from tvm.relay import vm

Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def tree_to_dict(t):
def vmobj_to_list(o, dtype="float32"):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.container.ADT):
elif isinstance(o, tvm.runtime.container.ADT):
if len(o) == 0:
tensor_nil = p.get_var("tensor_nil", dtype=dtype)
if tensor_nil.tag == o.tag:
Expand Down
3 changes: 2 additions & 1 deletion tests/python/relay/test_backend_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import tvm
import tvm.testing
from tvm import nd
from tvm import relay, container
from tvm import relay
from tvm.runtime import container
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor
Expand Down
8 changes: 4 additions & 4 deletions tests/python/relay/test_external_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
import os
import sys
import numpy as np
import pytest

import tvm
import tvm.relay.testing
import tvm.relay.transform
from tvm import relay
from tvm import runtime
from tvm.contrib import util

def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
Expand All @@ -49,11 +49,11 @@ def update_lib(lib):

def check_vm_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
exe = relay.vm.compile(mod, target=target)
exe = runtime.vm.compile(mod, target=target)
code, lib = exe.save()
lib = update_lib(lib)
exe = relay.vm.Executable.load_exec(code, lib)
vm = relay.vm.VirtualMachine(exe)
exe = runtime.vm.Executable.load_exec(code, lib)
vm = runtime.vm.VirtualMachine(exe)
vm.init(ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
Expand Down
9 changes: 5 additions & 4 deletions tests/python/relay/test_pass_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import tvm.relay.testing
import tvm.relay.transform as transform
from tvm import relay
from tvm import runtime
from tvm.contrib import util
from tvm.relay.annotation import compiler_begin, compiler_end
from tvm.relay.expr_functor import ExprMutator
Expand Down Expand Up @@ -182,17 +183,17 @@ def update_lib(lib):
lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.runtime.load_module(lib_path)
lib = runtime.load_module(lib_path)

return lib

def check_vm_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
exe = relay.vm.compile(mod, target=target, params=params)
exe = runtime.vm.compile(mod, target=target, params=params)
code, lib = exe.save()
lib = update_lib(lib)
exe = relay.vm.Executable.load_exec(code, lib)
vm = relay.vm.VirtualMachine(exe)
exe = runtime.vm.Executable.load_exec(code, lib)
vm = runtime.vm.VirtualMachine(exe)
vm.init(ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
Expand Down
Loading