Skip to content

Commit

Permalink
Add type hint. (apache#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuchenJin authored and yongwww committed Aug 14, 2022
1 parent 3d11bab commit 9c5ae33
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 26 deletions.
34 changes: 21 additions & 13 deletions python/tvm/relax/exec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
# under the License.

from enum import IntEnum
from typing import Optional, Union, List
import tvm
from tvm._ffi._ctypes.packed_func import TVMRetValueHandle
from tvm.runtime import Object
from tvm._ffi.base import _LIB, check_call
from . vm import Executable
from . import _ffi_api

class SpecialReg(IntEnum):
Expand All @@ -40,40 +43,46 @@ def __exit__(self, ptype, value, trace):
@tvm._ffi.register_object("relax.ExecBuilder")
class ExecBuilder(Object):
"""A builder to emit instructions and build executable for the virtual machine."""
def __init__(self):

def __init__(self) -> None:
self.__init_handle_by_constructor__(_ffi_api.ExecBuilderCreate)

def r(self, idx):
def r(self, idx: int) -> int:
"""set instruction's argument as a register."""
return _ffi_api.ExecBuilderR(self, idx)

def imm(self, value):
def imm(self, value: int) -> int:
"""set instruction's argument as an immediate."""
return _ffi_api.ExecBuilderImm(self, value)

def c(self, idx):
def c(self, idx: int) -> int:
"""set instruction's argument as a constant."""
return _ffi_api.ExecBuilderC(self, idx)

def void_arg(self):
def void_arg(self) -> int:
return self.r(SpecialReg.VOID_ARG)

def vm_state(self):
def vm_state(self) -> int:
return self.r(SpecialReg.VM_STATE)

def function(self, func_name, num_inputs=0):
def function(self, func_name: str, num_inputs: Optional[int] = 0) -> VMFuncScope:
"""annotate a VM function."""
_ffi_api.ExecBuilderFunction(self, func_name, num_inputs)
return VMFuncScope()

def _check_scope(self):
def _check_scope(self) -> None:
if len(VMFuncScope.stack) == 0:
raise ValueError("emit should happen in a function scope")

def emit_constant(self, const):
def emit_constant(self, const: TVMRetValueHandle) -> int:
return _ffi_api.ExecBuilderEmitConstant(self, const)

def emit_call(self, name, args=[], dst=None):
def emit_call(
self,
name: str,
args: Optional[List[Union[tvm.nd.NDArray, tvm.DataType]]] = [],
dst: int = None,
) -> None:
"""emit a call instruction which calls a packed function."""
self._check_scope()
if dst is None:
Expand All @@ -87,12 +96,11 @@ def emit_call(self, name, args=[], dst=None):
args_.append(arg)
_ffi_api.ExecBuilderEmitCall(self, name, args_, dst)

def emit_ret(self, result):
def emit_ret(self, result: int) -> None:
"""emit a return instruction"""
self._check_scope()
_ffi_api.ExecBuilderEmitRet(self, result)

def get(self):
def get(self) -> Executable:
"""return the executable"""
return _ffi_api.ExecBuilderGet(self)

38 changes: 26 additions & 12 deletions python/tvm/relax/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
# specific language governing permissions and limitations
# under the License.

from typing import List, Optional, Union, Dict
import tvm
from tvm.runtime import Object
from tvm.runtime import Object, Device, Module, PackedFunc
from tvm._ffi.base import _LIB, check_call
from . import _ffi_api
from ..rpc.base import RPC_SESS_MASK
Expand All @@ -25,35 +26,45 @@
@tvm._ffi.register_object("relax.Executable")
class Executable(Object):
"""The executable object emitted by the VM compiler or the ExecBuilder."""

def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.Executable)

def stats(self):
def stats(self) -> str:
"""print the detailed statistics of the executable."""
return _ffi_api.ExecutableStats(self)

def save_to_file(self, file_name):
def save_to_file(self, file_name: str) -> None:
"""serialize and write the executable to a file."""
return _ffi_api.ExecutableSaveToFile(self, file_name)
_ffi_api.ExecutableSaveToFile(self, file_name)

def astext(self):
def astext(self) -> str:
"""print the instructions as text format."""
return _ffi_api.ExecutableAsText(self)
def aspython(self):

def aspython(self) -> str:
"""print the instructions as python program."""
return _ffi_api.ExecutableAsPython(self)

def load_exec_from_file(file_name):

def load_exec_from_file(file_name: str) -> Executable:
return _ffi_api.ExecutableLoadFromFile(file_name)


class VirtualMachine(object):
"""Relax VM runtime."""

NAIVE_ALLOCATOR = 1
POOLED_ALLOCATOR = 2

def __init__(self, exec, device, memory_cfg=None, mod=None):

def __init__(
self,
exec: Executable,
device: Union[Device, List[Device]],
memory_cfg: Optional[Union[str, Dict[Device, str]]] = None,
mod: Optional[Module] = None,
) -> None:

"""
Construct a VirtualMachine wrapper object.
Expand All @@ -73,6 +84,9 @@ def __init__(self, exec, device, memory_cfg=None, mod=None):
type specified in the dict, or pooled allocator if not specified in the
dict.
mod : tvm.runtime.Module, optional
Optional runtime module to load to the VM.
Returns
-------
vm: VirtualMachine
Expand All @@ -81,7 +95,7 @@ def __init__(self, exec, device, memory_cfg=None, mod=None):
self.module = _ffi_api.VirtualMachine(exec, mod)
self._setup_device(device, memory_cfg)

def _setup_device(self, dev, memory_cfg):
def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) -> None:
"""init devices and allocators."""
devs = dev
if not isinstance(dev, (list, tuple)):
Expand Down Expand Up @@ -117,5 +131,5 @@ def _setup_device(self, dev, memory_cfg):
init_args.append(alloc_type)
_ffi_api.VirtualMachineInit(self.module, *init_args)

def __getitem__(self, key):
def __getitem__(self, key: str) -> PackedFunc:
return self.module[key]
2 changes: 1 addition & 1 deletion src/relax/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ TVM_REGISTER_GLOBAL("relax.ExecutableAsPython").set_body_typed([](Executable exe

TVM_REGISTER_GLOBAL("relax.ExecutableSaveToFile")
.set_body_typed([](Executable exec, std::string file_name) {
return exec->SaveToFile(file_name);
exec->SaveToFile(file_name);
});

TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed([](std::string file_name) {
Expand Down

0 comments on commit 9c5ae33

Please sign in to comment.