From 4a78f2efaaacf4a90a499dc22c1f54cec11ba166 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Mon, 7 Oct 2019 19:05:46 -0700 Subject: [PATCH] Refactor --- python/tvm/relay/backend/profiler_vm.py | 19 +----- python/tvm/relay/backend/vm.py | 77 +++++++++++++------------ 2 files changed, 44 insertions(+), 52 deletions(-) diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index 544722ceecb5f..5aa6d90183766 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -83,25 +83,12 @@ def compile(self, mod, target=None, target_host=None, params=None): The profile VM runtime. """ target = _update_target(target) + target_host = self.update_target_host(target, target_host) + if params: self.set_params(params) - target_host = None if target_host == "" else target_host - if not target_host: - for device_type, tgt in target.items(): - if device_type.value == tvm.nd.cpu(0).device_type: - target_host = tgt - break - if not target_host: - target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" - target_host = tvm.target.create(target_host) - - # If current dispatch context is fallback context (the default root context), - # then load pre-tuned parameters from TopHub - if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): - tophub_context = autotvm.tophub.context(list(target.values())) - else: - tophub_context = autotvm.util.EmptyContext() + tophub_context = self.tophub_context(target) with tophub_context: self._compile(mod, target, target_host) diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 563423ae34719..f4182f15306cb 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -30,26 +30,6 @@ from . import vmobj as _obj from .interpreter import Executor - -def _update_target(target): - target = target if target else tvm.target.current_target() - if target is None: - raise ValueError("Target is not set in env or passed as argument.") - - tgts = {} - if isinstance(target, (str, tvm.target.Target)): - dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type) - tgts[dev_type] = tvm.target.create(target) - elif isinstance(target, dict): - for dev, tgt in target.items(): - dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type) - tgts[dev_type] = tvm.target.create(tgt) - else: - raise TypeError("target is expected to be str, tvm.target.Target, " + - "or dict of str to str/tvm.target.Target, but received " + - "{}".format(type(target))) - return tgts - def _convert(arg, cargs): if isinstance(arg, (np.ndarray, tvm.nd.NDArray)): cargs.append(_obj.tensor_object(arg)) @@ -161,6 +141,44 @@ def set_params(self, params): inputs[name] = _expr.const(param) self._set_params_func(inputs) + def update_target(self, target): + target = target if target else tvm.target.current_target() + if target is None: + raise ValueError("Target is not set in env or passed as argument.") + tgts = {} + if isinstance(target, (str, tvm.target.Target)): + dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type) + tgts[dev_type] = tvm.target.create(target) + elif isinstance(target, dict): + for dev, tgt in target.items(): + dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type) + tgts[dev_type] = tvm.target.create(tgt) + else: + raise TypeError("target is expected to be str, tvm.target.Target, " + + "or dict of str to str/tvm.target.Target, but received " + + "{}".format(type(target))) + return tgts + + def update_target_host(self, target, target_host): + target_host = None if target_host == "" else target_host + if not target_host: + for device_type, tgt in target.items(): + if device_type.value == tvm.nd.cpu(0).device_type: + target_host = tgt + break + if not target_host: + target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" + return tvm.target.create(target_host) + + def tophub_context(self, target): + # If current dispatch context is fallback context (the default root context), + # then load pre-tuned parameters from TopHub + if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): + tophub_context = autotvm.tophub.context(list(target.values())) + else: + tophub_context = autotvm.util.EmptyContext() + return tophub_context + def compile(self, mod, target=None, target_host=None, params=None): """ Parameters @@ -191,26 +209,13 @@ def compile(self, mod, target=None, target_host=None, params=None): vm : VirtualMachine The VM runtime. """ - target = _update_target(target) - target_host = None if target_host == "" else target_host - if not target_host: - for device_type, tgt in target.items(): - if device_type.value == tvm.nd.cpu(0).device_type: - target_host = tgt - break - if not target_host: - target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" - target_host = tvm.target.create(target_host) + target = self.update_target(target) + target_host = self.update_target_host(target, target_host) if params: self.set_params(params) - # If current dispatch context is fallback context (the default root context), - # then load pre-tuned parameters from TopHub - if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): - tophub_context = autotvm.tophub.context(list(target.values())) - else: - tophub_context = autotvm.util.EmptyContext() + tophub_context = self.tophub_context(target) with tophub_context: self._compile(mod, target, target_host)