move fallback out of the build interface (#2456)
zhiics authored and tqchen committed Jan 17, 2019
1 parent 985e7d7 commit b374192
Showing 2 changed files with 17 additions and 22 deletions.
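In short: fallback_device moves off the relay.build signature and into BuildConfig, so heterogeneous builds now set it through relay.build_config. A minimal before/after sketch of a call site (the target dict, CUDA device, and the func/params variables are assumptions, mirroring the tests below):

    import tvm
    from tvm import relay

    # func and params are assumed given: a Relay function plus its weights.
    target = {"cpu": "llvm", "cuda": "cuda"}

    # Old interface, removed by this commit:
    #   graph, lib, params = relay.build(func, target, params=params,
    #                                    fallback_device=tvm.cpu(0))

    # New interface: the fallback device lives in the build config.
    with relay.build_config(opt_level=1, fallback_device=tvm.cpu(0)):
        graph, lib, params = relay.build(func, target, params=params)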
24 changes: 10 additions & 14 deletions python/tvm/relay/build_module.py
@@ -36,6 +36,7 @@ class BuildConfig(object):
     defaults = {
         "opt_level": 2,
         "add_pass": None,
+        "fallback_device": None,
     }

     def __init__(self, **kwargs):
@@ -96,6 +97,10 @@ def build_config(**kwargs):
     add_pass: set of str
         Optimization pass to be added regardless of optimization level.

+    fallback_device : str or tvm.TVMContext
+        The fallback device. It is also used as the default device for
+        operators without specified device during heterogeneous execution.
+
     Returns
     -------
     config: BuildConfig
@@ -192,8 +197,7 @@ def optimize(func, target, params=None):
     return func


-def build(func, target=None, target_host=None, params=None,
-          fallback_device=None):
+def build(func, target=None, target_host=None, params=None):
     """Build a function to run on TVM graph runtime.

     Parameters
@@ -219,10 +223,6 @@ def build(func, target=None, target_host=None, params=None,
         Input parameters to the graph that do not change
         during inference time. Used for constant folding.

-    fallback_device : str or tvm.TVMContext, optional.
-        The fallback device. It is also used as the default device for
-        operators with no specified device.
-
     Returns
     -------
     graph_json : str
@@ -239,8 +239,7 @@ def build(func, target=None, target_host=None, params=None,
         raise ValueError("Target is not set in env or passed as argument.")

     if isinstance(target, dict):
-        target, fallback_device = \
-            _update_heterogeneous_inputs(target, fallback_device)
+        target, fallback_device = _update_heterogeneous_inputs(target)
     elif isinstance(target, (str, _target.Target)):
         target = _target.create(target)
     else:
@@ -277,7 +276,7 @@ def build(func, target=None, target_host=None, params=None,
     return graph_json, mod, params


-def _update_heterogeneous_inputs(target, fallback_device=None):
+def _update_heterogeneous_inputs(target):
     """Update the target and fallback device required for heterogeneous
     compilation. CPU is used as the fallback device if it wasn't provided.
     Meanwhile, a CPU device type and "llvm" pair will be added to the target
@@ -288,10 +287,6 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
     target : dict of str(i.e. device/context name) to str/tvm.target.Target.
         A dict contains context to target pairs.

-    fallback_device : str or tvm.TVMContext, optional.
-        The fallback device. It is also used as the default device for
-        operators with no specified device.
-
     Returns
     -------
     device_target : dict of int to tvm.target.Target.
@@ -305,6 +300,7 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
                          "heterogeneous execution, but received %s."
                          % type(target))

+    fallback_device = BuildConfig.current.fallback_device
     if fallback_device is None:
         # cpu is used as the default fallback device when heterogeneous
         # execution is needed, but no fallback device is provided.
@@ -315,7 +311,7 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
     elif isinstance(fallback_device, TVMContext):
         fallback_device = fallback_device.device_type
     else:
-        raise ValueError("fallback_device expects the type of str or" +
+        raise ValueError("fallback_device expects the type of str or " +
                          "TVMContext, but received %s." % type(fallback_device))

     device_target = {}
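For reference, _update_heterogeneous_inputs now reads the fallback from the active build config instead of an argument. A condensed, self-contained sketch of the resolution it performs; the str branch is collapsed in the diff view above, so its body here is an assumption modeled on the visible TVMContext branch:

    import tvm
    from tvm import TVMContext

    def resolve_fallback_device(fallback_device):
        # fallback_device now comes from BuildConfig.current.fallback_device.
        if fallback_device is None:
            # CPU is the default fallback for heterogeneous execution.
            return tvm.cpu(0).device_type
        if isinstance(fallback_device, str):
            # Assumed: a device name is converted through tvm.context.
            return tvm.context(fallback_device).device_type
        if isinstance(fallback_device, TVMContext):
            return fallback_device.device_type
        raise ValueError("fallback_device expects the type of str or "
                         "TVMContext, but received %s." % type(fallback_device))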
15 changes: 7 additions & 8 deletions tests/python/relay/test_pass_annotation.py
@@ -3,7 +3,6 @@

 import tvm
 from tvm import relay
-from tvm.relay import testing
 from tvm.contrib import graph_runtime

@@ -248,12 +247,14 @@ def get_func():

 def test_runtime(target, device, func, fallback_device=None):
     params = {"x": x_data, "y": y_data}
-    with relay.build_config(opt_level=1):
+    config = {"opt_level": 1}
+    if fallback_device:
+        config["fallback_device"] = fallback_device
+    with relay.build_config(**config):
         graph, lib, params = relay.build(
             func,
             target,
-            params=params,
-            fallback_device=fallback_device)
+            params=params)
     contexts = [tvm.cpu(0), tvm.context(device)]
     mod = graph_runtime.create(graph, lib, contexts)
     mod.set_input(**params)
@@ -367,13 +368,11 @@ def annotated():
     test_runtime(target, device, annotated_func, fallback_device)

 def test_fallback_all_operators(device, tgt):
-    target = {"cpu": "llvm", device: tgt}
-    fallback_device = tvm.cpu(0)
-
+    target = {device: tgt}
     annotated_func = get_func()
     expected_func = get_func()
     check_annotated_graph(annotated_func, expected_func)
-    test_runtime(target, device, annotated_func, fallback_device)
+    test_runtime(target, device, annotated_func)

 for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"),
                  ("opencl", str(tvm.target.intel_graphics()))]:
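One consequence the updated test exercises: with fallback_device dropped from the build call, a target dict that omits "cpu" now picks up tvm.cpu(0) implicitly. A hypothetical invocation, assuming get_func and test_runtime as defined in this file:

    annotated_func = get_func()
    target = {"cuda": "cuda"}  # no explicit cpu entry or fallback device
    # CPU fallback is applied inside relay.build via the BuildConfig default.
    test_runtime(target, "cuda", annotated_func)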
