diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 3c4aeea5af67..03ad23ef6a62 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -200,7 +200,7 @@ def __init__(self, class LocalRunner(ProgramRunner): """ LocalRunner that uses local CPU/GPU to measures the time cost of programs. - TODO(FrozenGene): Add cpu cache flush to this runner + TODO(FrozenGene): Add cpu cache flush to this runner. Parameters ---------- @@ -237,7 +237,6 @@ def __init__(self, _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) - @tvm._ffi.register_object("auto_scheduler.RPCRunner") class RPCRunner(ProgramRunner): """ RPCRunner that uses RPC call to measures the time cost of programs on remote devices. @@ -433,7 +432,8 @@ def timed_func(): dirname, "tmp_func." + build_func.output_format) try: - with transform.PassContext(): # todo(lmzheng): port the unroll pass + # TODO(merrymercy): Port the unroll pass. + with transform.PassContext(): func = build_module.build( sch, args, target=task.target, target_host=task.target_host) func.export_library(filename, build_func) @@ -565,7 +565,7 @@ def timed_func(inp, build_res): if error_no == 0: try: - # TODO(FrozenGene): Update to ndarray.non-empty + # TODO(FrozenGene): Update to ndarray.non-empty. args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] ctx.sync() @@ -656,7 +656,7 @@ def timed_func(): if error_no == 0: try: - # TODO(FrozenGene): Update to ndarray.non-empty + # TODO(FrozenGene): Update to ndarray.non-empty. args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] ctx.sync() diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index f34db44fc4b0..f5b53fb2a446 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -199,26 +199,27 @@ def func_wrapper(que): def request_remote(device_key, host=None, port=None, priority=1, timeout=60): - """Request a remote session + """ Request a remote session. Parameters ---------- device_key : str - The device key of registered device in tracker + The device key of registered device in tracker. host : Optional[str] The host address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_HOST" + If is none, will use environment variable "TVM_TRACKER_HOST". port : Optional[int] The port of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_PORT" + If is none, will use environment variable "TVM_TRACKER_PORT". priority : int = 1 - The priority of this request, larger is more prior + The priority of this request, larger is more prior. timeout : int = 60 - The timeout of this session (units: second) + The timeout of this session in second. Returns ------- - session : RPCSession + remote : RPCSession + The connected remote RPCSession. """ # connect to the tracker host = host or os.environ['TVM_TRACKER_HOST'] @@ -232,27 +233,27 @@ def request_remote(device_key, host=None, port=None, priority=1, timeout=60): def check_remote(device_key, host=None, port=None, priority=100, timeout=10): """ - Check the availability of a remote device + Check the availability of a remote device. Parameters ---------- device_key: str - device key of registered device in tracker + device key of registered device in tracker. host: Optional[str] The host address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_HOST" + If is none, will use environment variable "TVM_TRACKER_HOST". port: Optional[int] The port address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_PORT" + If is none, will use environment variable "TVM_TRACKER_PORT". priority: int = 100 - The priority of this request, larger is more prior + The priority of this request, larger is more prior. timeout: int = 10 - The timeout of this check (units: seconds). + The timeout of this check in seconds. Returns ------- available: bool - True if can find available device + True if can find available device. """ def _check():