Skip to content

Commit

Permalink
[MetaSchedule] Use current pass context in compile_relay, extract_tasks
Browse files Browse the repository at this point in the history
Adds the pass config information necessary for tuning and compiling
relay with metaschedule to the existing pass context instead of
overriding the existing one. Allows users to pass in their own pass
instruments, required passes, and disabled passes. This also keeps the
same API used to compile relay with autotvm and auto_scheduler.
  • Loading branch information
Tristan Konolige committed Jan 3, 2023
1 parent 6dbb7e1 commit b26d443
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 59 deletions.
75 changes: 21 additions & 54 deletions python/tvm/meta_schedule/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,8 @@ def _normalize_params(
mod: IRModule,
target: Union[Target, str],
params: Optional[Dict[str, NDArray]],
pass_config: Mapping[str, Any],
executor: Optional["relay.backend.Executor"],
) -> Tuple[
IRModule,
Target,
Dict[str, NDArray],
Dict[str, Any],
Optional["relay.backend.Executor"],
]:
) -> Tuple[IRModule, Target, Dict[str, NDArray], Optional["relay.backend.Executor"],]:
from tvm import relay # pylint: disable=import-outside-toplevel

if isinstance(mod, relay.Function):
Expand All @@ -102,25 +95,16 @@ def _normalize_params(
else:
executor = mod.get_attr("executor")

pass_config = dict(pass_config)
return mod, target, relay_params, pass_config, executor
return mod, target, relay_params, executor


def extract_tasks(
mod: IRModule,
target: Union[Target, str],
params: Optional[Dict[str, NDArray]],
*,
opt_level: int = 3,
pass_config: Mapping[str, Any] = MappingProxyType(
{
"relay.backend.use_meta_schedule": True,
"relay.backend.tir_converter": "default",
}
),
executor: Optional["relay.backend.Executor"] = None,
module_equality: str = "structural",
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relay program.
Expand All @@ -132,10 +116,6 @@ def extract_tasks(
The compilation target
params : Optional[Dict[str, tvm.runtime.NDArray]]
The associated parameters of the program
opt_level : int
The optimization level of the compilation
pass_config : Mapping[str, Any]
The pass configuration
executor : Optional[relay.backend.Executor]
The executor to use
module_equality : Optional[str]
Expand All @@ -148,8 +128,6 @@ def extract_tasks(
given module. The "ignore-ndarray" varint is used for the extracted
blocks or in case no anchor block is found.
For the definition of the anchor block, see tir/analysis/analysis.py.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of disabled passes
Returns
-------
Expand All @@ -160,21 +138,25 @@ def extract_tasks(
from tvm import autotvm

# pylint: enable=import-outside-toplevel
mod, target, params, pass_config, _ = _normalize_params(
mod, target, params, pass_config, executor
)
mod, target, params, _ = _normalize_params(mod, target, params, executor)
if target.kind.name != "cuda" and isinstance(
autotvm.DispatchContext.current, autotvm.FallbackContext
):
tophub_context = autotvm.tophub.context(target)
else:
tophub_context = autotvm.utils.EmptyContext()
pass_ctx = transform.PassContext.current()
pass_config = dict(pass_ctx.config)
pass_config.setdefault("relay.backend.use_meta_schedule", True)
pass_config.setdefault("relay.backend.tir_converter", "default")
with Profiler.timeit("TaskExtraction"):
with target, _autotvm_silencer(), tophub_context:
with transform.PassContext(
opt_level=opt_level,
opt_level=pass_ctx.opt_level,
required_pass=pass_ctx.required_pass,
disabled_pass=pass_ctx.disabled_pass,
instruments=pass_ctx.instruments,
config=pass_config,
disabled_pass=disabled_pass,
):
return list(_extract_task(mod, target, params, module_equality))

Expand Down Expand Up @@ -254,7 +236,6 @@ def tune_relay(
seed: Optional[int] = None,
module_equality: str = "structural",
num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
) -> Database:
"""Tune a Relay program.
Expand Down Expand Up @@ -304,18 +285,14 @@ def tune_relay(
For the definition of the anchor block, see tir/analysis/analysis.py.
num_tuning_cores : Union[Literal["physical", "logical"], int]
The number of CPU cores to use during tuning.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of disabled passes during tasks extraction
Returns
-------
database : Database
The database that contains the tuning records
"""
tasks, task_weights = extracted_tasks_to_tune_contexts(
extracted_tasks=extract_tasks(
mod, target, params, module_equality=module_equality, disabled_pass=disabled_pass
),
extracted_tasks=extract_tasks(mod, target, params, module_equality=module_equality),
work_dir=work_dir,
space=space,
strategy=strategy,
Expand Down Expand Up @@ -346,15 +323,7 @@ def compile_relay(
params: Optional[Dict[str, NDArray]],
*,
backend: Literal["graph", "vm"] = "graph",
opt_level: int = 3,
pass_config: Mapping[str, Any] = MappingProxyType(
{
"relay.backend.use_meta_schedule": True,
"relay.backend.tir_converter": "default",
}
),
executor: Optional["relay.backend.Executor"] = None,
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
):
"""Compile a relay program with a MetaSchedule database.
Expand All @@ -372,14 +341,8 @@ def compile_relay(
The backend to use. Builtin backends:
- "graph"
- "vm"
opt_level : int
The optimization level of the compilation
pass_config : Mapping[str, Any]
The pass configuration
executor : Optional[relay.backend.Executor]
The executor to use in relay.build. It is not supported by RelayVM.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of disabled passes
Returns
-------
Expand All @@ -390,16 +353,20 @@ def compile_relay(
from tvm import relay

# pylint: enable=import-outside-toplevel
mod, target, params, pass_config, executor = _normalize_params(
mod, target, params, pass_config, executor
)
mod, target, params, executor = _normalize_params(mod, target, params, executor)
pass_ctx = transform.PassContext.current()
pass_config = dict(pass_ctx.config)
pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", True)
pass_config.setdefault("relay.backend.use_meta_schedule", True)
pass_config.setdefault("relay.backend.tir_converter", "default")
with Profiler.timeit("PostTuningCompilation"):
with target, _autotvm_silencer(), database:
with transform.PassContext(
opt_level=opt_level,
opt_level=pass_ctx.opt_level,
required_pass=pass_ctx.required_pass,
disabled_pass=pass_ctx.disabled_pass,
instruments=pass_ctx.instruments,
config=pass_config,
disabled_pass=disabled_pass,
):
if backend == "graph":
return relay.build(mod, target=target, params=params, executor=executor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def schedule_conv2d_for_tune(sch: Schedule):
else None
)

with tempfile.TemporaryDirectory() as work_dir:
with tempfile.TemporaryDirectory() as work_dir, pass_context:
database = ms.relay_integration.tune_relay(
mod=mod,
target=TARGET_HEXAGON,
Expand Down Expand Up @@ -382,15 +382,11 @@ def schedule_conv2d_for_tune(sch: Schedule):
module_equality="ignore-ndarray",
)

# Add default options so that it still uses the base config.
pass_config["relay.backend.use_meta_schedule"] = True
pass_config["relay.backend.tir_converter"] = "default"
return ms.relay_integration.compile_relay(
database=database,
mod=mod,
target=TARGET_HEXAGON,
params=params,
pass_config=pass_config,
)


Expand Down

0 comments on commit b26d443

Please sign in to comment.