From 9c5ae33e18227d211cad04476fbeaa45b605cdd1 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Wed, 6 Oct 2021 15:35:05 -0700 Subject: [PATCH] Add type hint. (#20) --- python/tvm/relax/exec_builder.py | 34 +++++++++++++++++----------- python/tvm/relax/vm.py | 38 ++++++++++++++++++++++---------- src/relax/vm/executable.cc | 2 +- 3 files changed, 48 insertions(+), 26 deletions(-) diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py index a8a834a1602c..bb7c2458c797 100644 --- a/python/tvm/relax/exec_builder.py +++ b/python/tvm/relax/exec_builder.py @@ -16,9 +16,12 @@ # under the License. from enum import IntEnum +from typing import Optional, Union, List import tvm +from tvm._ffi._ctypes.packed_func import TVMRetValueHandle from tvm.runtime import Object from tvm._ffi.base import _LIB, check_call +from . vm import Executable from . import _ffi_api class SpecialReg(IntEnum): @@ -40,40 +43,46 @@ def __exit__(self, ptype, value, trace): @tvm._ffi.register_object("relax.ExecBuilder") class ExecBuilder(Object): """A builder to emit instructions and build executable for the virtual machine.""" - def __init__(self): + + def __init__(self) -> None: self.__init_handle_by_constructor__(_ffi_api.ExecBuilderCreate) - def r(self, idx): + def r(self, idx: int) -> int: """set instruction's argument as a register.""" return _ffi_api.ExecBuilderR(self, idx) - def imm(self, value): + def imm(self, value: int) -> int: """set instruction's argument as an immediate.""" return _ffi_api.ExecBuilderImm(self, value) - def c(self, idx): + def c(self, idx: int) -> int: """set instruction's argument as a constant.""" return _ffi_api.ExecBuilderC(self, idx) - def void_arg(self): + def void_arg(self) -> int: return self.r(SpecialReg.VOID_ARG) - def vm_state(self): + def vm_state(self) -> int: return self.r(SpecialReg.VM_STATE) - def function(self, func_name, num_inputs=0): + def function(self, func_name: str, num_inputs: Optional[int] = 0) -> VMFuncScope: """annotate a VM function.""" _ffi_api.ExecBuilderFunction(self, func_name, num_inputs) return VMFuncScope() - def _check_scope(self): + def _check_scope(self) -> None: if len(VMFuncScope.stack) == 0: raise ValueError("emit should happen in a function scope") - def emit_constant(self, const): + def emit_constant(self, const: TVMRetValueHandle) -> int: return _ffi_api.ExecBuilderEmitConstant(self, const) - def emit_call(self, name, args=[], dst=None): + def emit_call( + self, + name: str, + args: Optional[List[Union[tvm.nd.NDArray, tvm.DataType]]] = [], + dst: int = None, + ) -> None: """emit a call instruction which calls a packed function.""" self._check_scope() if dst is None: @@ -87,12 +96,11 @@ def emit_call(self, name, args=[], dst=None): args_.append(arg) _ffi_api.ExecBuilderEmitCall(self, name, args_, dst) - def emit_ret(self, result): + def emit_ret(self, result: int) -> None: """emit a return instruction""" self._check_scope() _ffi_api.ExecBuilderEmitRet(self, result) - def get(self): + def get(self) -> Executable: """return the executable""" return _ffi_api.ExecBuilderGet(self) - diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index ed8a1884230d..08237d51f850 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. +from typing import List, Optional, Union, Dict import tvm -from tvm.runtime import Object +from tvm.runtime import Object, Device, Module, PackedFunc from tvm._ffi.base import _LIB, check_call from . import _ffi_api from ..rpc.base import RPC_SESS_MASK @@ -25,35 +26,45 @@ @tvm._ffi.register_object("relax.Executable") class Executable(Object): """The executable object emitted by the VM compiler or the ExecBuilder.""" + def __init__(self): self.__init_handle_by_constructor__(_ffi_api.Executable) - def stats(self): + def stats(self) -> str: """print the detailed statistics of the executable.""" return _ffi_api.ExecutableStats(self) - def save_to_file(self, file_name): + def save_to_file(self, file_name: str) -> None: """serialize and write the executable to a file.""" - return _ffi_api.ExecutableSaveToFile(self, file_name) + _ffi_api.ExecutableSaveToFile(self, file_name) - def astext(self): + def astext(self) -> str: """print the instructions as text format.""" return _ffi_api.ExecutableAsText(self) - - def aspython(self): + + def aspython(self) -> str: """print the instructions as python program.""" return _ffi_api.ExecutableAsPython(self) -def load_exec_from_file(file_name): + +def load_exec_from_file(file_name: str) -> Executable: return _ffi_api.ExecutableLoadFromFile(file_name) + class VirtualMachine(object): """Relax VM runtime.""" NAIVE_ALLOCATOR = 1 POOLED_ALLOCATOR = 2 - - def __init__(self, exec, device, memory_cfg=None, mod=None): + + def __init__( + self, + exec: Executable, + device: Union[Device, List[Device]], + memory_cfg: Optional[Union[str, Dict[Device, str]]] = None, + mod: Optional[Module] = None, + ) -> None: + """ Construct a VirtualMachine wrapper object. @@ -73,6 +84,9 @@ def __init__(self, exec, device, memory_cfg=None, mod=None): type specified in the dict, or pooled allocator if not specified in the dict. + mod : tvm.runtime.Module, optional + Optional runtime module to load to the VM. + Returns ------- vm: VirtualMachine @@ -81,7 +95,7 @@ def __init__(self, exec, device, memory_cfg=None, mod=None): self.module = _ffi_api.VirtualMachine(exec, mod) self._setup_device(device, memory_cfg) - def _setup_device(self, dev, memory_cfg): + def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) -> None: """init devices and allocators.""" devs = dev if not isinstance(dev, (list, tuple)): @@ -117,5 +131,5 @@ def _setup_device(self, dev, memory_cfg): init_args.append(alloc_type) _ffi_api.VirtualMachineInit(self.module, *init_args) - def __getitem__(self, key): + def __getitem__(self, key: str) -> PackedFunc: return self.module[key] diff --git a/src/relax/vm/executable.cc b/src/relax/vm/executable.cc index 862005598716..1154f43ab249 100644 --- a/src/relax/vm/executable.cc +++ b/src/relax/vm/executable.cc @@ -462,7 +462,7 @@ TVM_REGISTER_GLOBAL("relax.ExecutableAsPython").set_body_typed([](Executable exe TVM_REGISTER_GLOBAL("relax.ExecutableSaveToFile") .set_body_typed([](Executable exec, std::string file_name) { - return exec->SaveToFile(file_name); + exec->SaveToFile(file_name); }); TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed([](std::string file_name) {