diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index df2cc105bc68..51a4ff873e0a 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -36,6 +36,7 @@ class BuildConfig(object):
     defaults = {
         "opt_level": 2,
         "add_pass": None,
+        "fallback_device": None,
     }
 
     def __init__(self, **kwargs):
@@ -96,6 +97,10 @@ def build_config(**kwargs):
     add_pass: set of str
         Optimization pass to be added regardless of optimization level.
 
+    fallback_device : str or tvm.TVMContext
+        The fallback device. It is also used as the default device for
+        operators without specified device during heterogeneous execution.
+
     Returns
     -------
     config: BuildConfig
@@ -192,8 +197,7 @@ def optimize(func, target, params=None):
     return func
 
 
-def build(func, target=None, target_host=None, params=None,
-          fallback_device=None):
+def build(func, target=None, target_host=None, params=None):
     """Build a function to run on TVM graph runtime.
 
     Parameters
@@ -219,10 +223,6 @@ def build(func, target=None, target_host=None, params=None,
         Input parameters to the graph that do not change during inference time.
         Used for constant folding.
 
-    fallback_device : str or tvm.TVMContext, optional.
-        The fallback device. It is also used as the default device for
-        operators with no specified device.
-
     Returns
     -------
     graph_json : str
@@ -239,8 +239,7 @@ def build(func, target=None, target_host=None, params=None,
         raise ValueError("Target is not set in env or passed as argument.")
 
     if isinstance(target, dict):
-        target, fallback_device = \
-            _update_heterogeneous_inputs(target, fallback_device)
+        target, fallback_device = _update_heterogeneous_inputs(target)
     elif isinstance(target, (str, _target.Target)):
         target = _target.create(target)
     else:
@@ -277,7 +276,7 @@ def build(func, target=None, target_host=None, params=None,
     return graph_json, mod, params
 
 
-def _update_heterogeneous_inputs(target, fallback_device=None):
+def _update_heterogeneous_inputs(target):
     """Update the target and fallback device required for heterogeneous
     compilation. CPU is used as the fallback device if it wasn't provided.
     Meanwhile, a CPU device type and "llvm" pair will be added to the target
@@ -288,10 +287,6 @@ def _update_heterogeneous_inputs(target, fallback_device=None):
     target : dict of str(i.e. device/context name) to str/tvm.target.Target.
         A dict contains context to target pairs.
 
-    fallback_device : str or tvm.TVMContext, optional.
-        The fallback device. It is also used as the default device for
-        operators with no specified device.
-
     Returns
     -------
     device_target : dict of int to tvm.target.Target.
@@ -305,6 +300,7 @@
                         "heterogeneous execution, but received %s."
                         % type(target))
 
+    fallback_device = BuildConfig.current.fallback_device
    if fallback_device is None:
        # cpu is used as the default fallback device when heterogeneous
        # execution is needed, but no fallback device is provided.
@@ -315,7 +311,7 @@
    elif isinstance(fallback_device, TVMContext):
        fallback_device = fallback_device.device_type
    else:
-        raise ValueError("fallback_device expects the type of str or" +
+        raise ValueError("fallback_device expects the type of str or " +
                         "TVMContext, but received %s." % type(fallback_device))
 
    device_target = {}
diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py
index 1808ecb818a8..9f54a9fa949f 100644
--- a/tests/python/relay/test_pass_annotation.py
+++ b/tests/python/relay/test_pass_annotation.py
@@ -3,7 +3,6 @@
 
 import tvm
 from tvm import relay
-from tvm.relay import testing
 from tvm.contrib import graph_runtime
 
 
@@ -248,12 +247,14 @@ def get_func():
 
     def test_runtime(target, device, func, fallback_device=None):
         params = {"x": x_data, "y": y_data}
-        with relay.build_config(opt_level=1):
+        config = {"opt_level": 1}
+        if fallback_device:
+            config["fallback_device"] = fallback_device
+        with relay.build_config(**config):
             graph, lib, params = relay.build(
                 func,
                 target,
-                params=params,
-                fallback_device=fallback_device)
+                params=params)
         contexts = [tvm.cpu(0), tvm.context(device)]
         mod = graph_runtime.create(graph, lib, contexts)
         mod.set_input(**params)
@@ -367,13 +368,11 @@ def annotated():
         test_runtime(target, device, annotated_func, fallback_device)
 
     def test_fallback_all_operators(device, tgt):
-        target = {"cpu": "llvm", device: tgt}
-        fallback_device = tvm.cpu(0)
-
+        target = {device: tgt}
         annotated_func = get_func()
         expected_func = get_func()
         check_annotated_graph(annotated_func, expected_func)
-        test_runtime(target, device, annotated_func, fallback_device)
+        test_runtime(target, device, annotated_func)
 
     for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"),
                      ("opencl", str(tvm.target.intel_graphics()))]:
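
For reference, a minimal sketch of the API after this patch: fallback_device now travels through relay.build_config and is read from BuildConfig.current inside _update_heterogeneous_inputs(), instead of being passed to relay.build(). The target dict, the "cuda" entry, and the func/params placeholders below are illustrative assumptions, not part of the patch.

import tvm
from tvm import relay

# Assumed inputs for this sketch: `func` is a device-annotated
# relay.Function and `params` is its weight dict.
targets = {"cpu": "llvm", "cuda": "cuda"}

# Operators without a device annotation fall back to tvm.cpu(0),
# picked up from the build config during heterogeneous compilation.
with relay.build_config(opt_level=1, fallback_device=tvm.cpu(0)):
    graph, lib, params = relay.build(func, targets, params=params)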