[REFACTOR][RELAY] move fallback_device to config (apache#5690)
zhiics authored and Trevor Morris committed Jun 18, 2020
1 parent 40d66b1 commit f3368a7
Showing 8 changed files with 20 additions and 43 deletions.
5 changes: 0 additions & 5 deletions include/tvm/ir/transform.h
@@ -92,9 +92,6 @@ class PassContextNode : public Object {
/*! \brief The default optimization level. */
int opt_level{2};

/*! \brief CPU is the default fallback device for heterogeneous execution. */
int fallback_device{static_cast<int>(kDLCPU)};

/*! \brief The list of required passes. */
Array<String> required_pass;
/*! \brief The list of disabled passes. */
@@ -139,7 +136,6 @@ class PassContextNode : public Object {

void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
v->Visit("disabled_pass", &disabled_pass);
v->Visit("config", &config);
Expand All @@ -157,7 +153,6 @@ class PassContextNode : public Object {
*
* auto new_ctx = PassContext::Create();
* ctx->opt_level = 2;
* ctx->fallback_device = kDLCPU;
* With<PassContext> scope(ctx);
* // pass context in effect.
*
18 changes: 1 addition & 17 deletions python/tvm/ir/transform.py
@@ -21,9 +21,7 @@
import functools

import tvm._ffi

import tvm.runtime
from tvm.runtime import ndarray as _nd

from . import _ffi_transform_api

@@ -61,10 +59,6 @@ class PassContext(tvm.runtime.Object):
opt_level : Optional[int]
The optimization level of this pass.
fallback_device : Optional[Union[int, str, TVMContext]]
The fallback device type. It is also used as the default device for
operators that are not annotated during heterogeneous execution.
required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are required by a certain pass.
@@ -76,19 +70,10 @@ class PassContext(tvm.runtime.Object):
"""
def __init__(self,
opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None,
trace=None,
config=None):
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, tvm.runtime.TVMContext):
fallback_device = fallback_device.device_type
if not isinstance(fallback_device, int):
raise TypeError("fallback_device is expected to be the type of " +
"int/str/TVMContext.")

required = list(required_pass) if required_pass else []
if not isinstance(required, (list, tuple)):
raise TypeError("required_pass is expected to be the type of " +
@@ -101,8 +86,7 @@ def __init__(self,

config = config if config else None
self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level,
fallback_device, required,
disabled, trace, config)
required, disabled, trace, config)

def __enter__(self):
_ffi_transform_api.EnterPassContext(self)
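For context (not part of the diff): with the fallback_device keyword gone from PassContext, callers express the same intent through the config map, under the relay.fallback_device_type key registered later in this commit. A minimal sketch of the new Python usage, assuming a stock TVM build; the choice of tvm.cpu(0) is purely illustrative:

    import tvm

    # The fallback device now travels as a pass-config entry rather than a
    # constructor argument; the value is the integer DLDeviceType code.
    fallback = tvm.cpu(0)
    with tvm.transform.PassContext(
            opt_level=2,
            config={"relay.fallback_device_type": fallback.device_type}):
        # Any relay.build(...) issued in this scope sees the configured
        # fallback device for operators that were not explicitly annotated.
        pass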
10 changes: 2 additions & 8 deletions python/tvm/relay/transform/transform.py
@@ -31,7 +31,6 @@


def build_config(opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None,
trace=None):
@@ -59,10 +58,6 @@ def build_config(opt_level=2,
"FastMath": 4
}
fallback_device : int, str, or tvmContext, optional
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
required_pass: set of str, optional
Optimization passes that are required regardless of optimization level.
@@ -77,9 +72,8 @@ def build_config(opt_level=2,
pass_context: PassContext
The pass context for optimizations.
"""
return tvm.ir.transform.PassContext(
opt_level, fallback_device, required_pass,
disabled_pass, trace)
return tvm.ir.transform.PassContext(opt_level, required_pass,
disabled_pass, trace)


@tvm._ffi.register_object("relay.FunctionPass")
7 changes: 2 additions & 5 deletions src/ir/transform.cc
@@ -454,12 +454,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(PassContextNode);

TVM_REGISTER_GLOBAL("transform.PassContext")
.set_body_typed([](int opt_level, int fallback_device, Array<String> required,
Array<String> disabled, TraceFunc trace_func,
Optional<Map<std::string, ObjectRef>> config) {
.set_body_typed([](int opt_level, Array<String> required, Array<String> disabled,
TraceFunc trace_func, Optional<Map<std::string, ObjectRef>> config) {
auto pctx = PassContext::Create();
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;

pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
@@ -477,7 +475,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "Pass context information: "
<< "\n";
p->stream << "\topt_level: " << node->opt_level << "\n";
p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n";

p->stream << "\trequired passes: [";
for (const auto& it : node->required_pass) {
7 changes: 6 additions & 1 deletion src/relay/backend/build_module.cc
@@ -304,7 +304,12 @@ class RelayBuildModule : public runtime::ModuleNode {
// Handle heterogeneous compilation.
transform::PassContext pass_ctx = PassContext::Current();
if (targets_.size() > 1) {
relay_module = RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device);
Optional<IntImm> opt_fallback_dev =
pass_ctx->GetConfig("relay.fallback_device_type",
IntImm(runtime::DataType::Int(32), static_cast<int>(kDLCPU)));
auto fallback_dev = opt_fallback_dev.value();
CHECK_GT(fallback_dev->value, 0U);
relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev->value);
}

// Fuse the operations if it is needed.
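The GetConfig call above supplies kDLCPU as the default, so heterogeneous builds that never set relay.fallback_device_type keep the old CPU-fallback behaviour, and the CHECK_GT rejects non-positive device types. A short sketch (not from the commit) of the integer codes involved, seen from Python; the GPU context is illustrative and only usable on a CUDA-enabled build:

    import tvm

    # DLDeviceType codes: kDLCPU == 1 (the default when the option is unset),
    # kDLGPU == 2, and so on.
    print(tvm.cpu(0).device_type)  # 1
    print(tvm.gpu(0).device_type)  # 2

    # Overriding the fallback so un-annotated operators land on the GPU.
    with tvm.transform.PassContext(
            config={"relay.fallback_device_type": tvm.gpu(0).device_type}):
        pass  # relay.build(...) here would use device type 2 as the fallback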
2 changes: 2 additions & 0 deletions src/relay/ir/transform.cc
@@ -30,6 +30,8 @@ namespace tvm {
namespace relay {
namespace transform {

TVM_REGISTER_PASS_CONFIG_OPTION("relay.fallback_device_type", IntImm);

class FunctionPass;

/*!
2 changes: 1 addition & 1 deletion tests/cpp/relay_transform_sequential.cc
@@ -70,7 +70,7 @@ TEST(Relay, Sequential) {
auto mod = IRModule::FromExpr(func);
auto pass_ctx = relay::transform::PassContext::Create();
pass_ctx->opt_level = 3;
pass_ctx->fallback_device = 1;
pass_ctx->config.Set("relay.fallback_device_type", IntImm(DataType::Int(32), 1));
{
tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
tvm::With<tvm::Target> tctx(tvm::Target::Create("llvm"));
12 changes: 6 additions & 6 deletions tests/python/relay/test_pass_annotation.py
@@ -344,10 +344,10 @@ def get_func():
def test_runtime(target, device, func, fallback_device=None,
expected_index=None):
params = {"x": x_data, "y": y_data}
config = {"opt_level": 1}
config = {}
if fallback_device:
config["fallback_device"] = fallback_device
with relay.build_config(**config):
config["relay.fallback_device_type"] = fallback_device.device_type
with tvm.transform.PassContext(opt_level=1, config=config):
graph, lib, params = relay.build(
func,
target,
@@ -538,9 +538,9 @@ def expected():
expected_index = [2, 2, 2, 1, 1, 1, 2, 2]
check_annotated_graph(annotated_func, expected_func)
params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data}
config = {"opt_level": 0}
config["fallback_device"] = fallback_device
with relay.build_config(**config):
with tvm.transform.PassContext(opt_level=0,
config={"relay.fallback_device_type":
fallback_device.device_type}):
graph, lib, params = relay.build(annotated_func, target, params=params)
contexts = [tvm.cpu(0), tvm.context(dev)]
graph_json = json.loads(graph)
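Putting the pieces together, a self-contained sketch of the migrated heterogeneous-build pattern these tests exercise; the tiny add graph, the on_device annotation and the CUDA target are illustrative assumptions, and running it requires a CUDA-enabled TVM build:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(3,))
    y = relay.var("y", shape=(3,))
    # Pin the addition to the GPU; anything left un-annotated falls back to
    # the device named by "relay.fallback_device_type".
    add = relay.annotation.on_device(relay.add(x, y), tvm.gpu(0))
    func = relay.Function([x, y], add)
    mod = tvm.IRModule.from_expr(func)

    target = {"cpu": "llvm", "cuda": "cuda"}
    fallback_device = tvm.cpu(0)

    with tvm.transform.PassContext(
            opt_level=1,
            config={"relay.fallback_device_type": fallback_device.device_type}):
        graph, lib, params = relay.build(mod, target)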
