Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
wweic committed Oct 8, 2019
1 parent 48c046a commit 4a78f2e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 52 deletions.
19 changes: 3 additions & 16 deletions python/tvm/relay/backend/profiler_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,12 @@ def compile(self, mod, target=None, target_host=None, params=None):
The profile VM runtime.
"""
target = _update_target(target)
target_host = self.update_target_host(target, target_host)

if params:
self.set_params(params)

target_host = None if target_host == "" else target_host
if not target_host:
for device_type, tgt in target.items():
if device_type.value == tvm.nd.cpu(0).device_type:
target_host = tgt
break
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
target_host = tvm.target.create(target_host)

# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.util.EmptyContext()
tophub_context = self.tophub_context(target)

with tophub_context:
self._compile(mod, target, target_host)
Expand Down
77 changes: 41 additions & 36 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,6 @@
from . import vmobj as _obj
from .interpreter import Executor


def _update_target(target):
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")

tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts

def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.tensor_object(arg))
Expand Down Expand Up @@ -161,6 +141,44 @@ def set_params(self, params):
inputs[name] = _expr.const(param)
self._set_params_func(inputs)

def update_target(self, target):
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts

def update_target_host(self, target, target_host):
target_host = None if target_host == "" else target_host
if not target_host:
for device_type, tgt in target.items():
if device_type.value == tvm.nd.cpu(0).device_type:
target_host = tgt
break
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
return tvm.target.create(target_host)

def tophub_context(self, target):
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.util.EmptyContext()
return tophub_context

def compile(self, mod, target=None, target_host=None, params=None):
"""
Parameters
Expand Down Expand Up @@ -191,26 +209,13 @@ def compile(self, mod, target=None, target_host=None, params=None):
vm : VirtualMachine
The VM runtime.
"""
target = _update_target(target)
target_host = None if target_host == "" else target_host
if not target_host:
for device_type, tgt in target.items():
if device_type.value == tvm.nd.cpu(0).device_type:
target_host = tgt
break
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
target_host = tvm.target.create(target_host)
target = self.update_target(target)
target_host = self.update_target_host(target, target_host)

if params:
self.set_params(params)

# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.util.EmptyContext()
tophub_context = self.tophub_context(target)

with tophub_context:
self._compile(mod, target, target_host)
Expand Down

0 comments on commit 4a78f2e

Please sign in to comment.