Skip to content

Commit

Permalink
[MetaSchedule] Add pass instrument to MetaSchedule api (apache#13688)
Browse files Browse the repository at this point in the history
* [MetaSchedule] Add pass instrument to MetaSchedule api

Add the `instrument` parameter from the `PassContext` api to the meta
schedule tuning api.

* lint
  • Loading branch information
Tristan Konolige authored and fzi-peccia committed Mar 27, 2023
1 parent b249b9a commit 292d088
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions python/tvm/meta_schedule/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@
"""MetaSchedule-Relay integration"""
from contextlib import contextmanager
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, Set
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple, Union

# isort: off
from typing_extensions import Literal

# isort: on
import numpy as np # type: ignore

from tvm import nd
from tvm._ffi import get_global_func
from tvm.ir import IRModule, transform
from tvm.ir.instrument import PassInstrument
from tvm.runtime import NDArray
from tvm.target import Target

Expand Down Expand Up @@ -127,6 +129,7 @@ def extract_tasks(
runtime: Optional["relay.backend.Runtime"] = None,
module_equality: str = "structural",
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
instruments: Optional[Sequence[PassInstrument]] = None,
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relay program.
Expand Down Expand Up @@ -158,6 +161,8 @@ def extract_tasks(
For the definition of the anchor block, see tir/analysis/analysis.py.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of disabled passes
instruments : Optional[Sequence[PassInstrument]]
The list of pass instrument implementations.
Returns
-------
Expand Down Expand Up @@ -188,6 +193,7 @@ def extract_tasks(
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
instruments=instruments,
):
return list(_extract_task(mod, target, params, module_equality))

Expand Down Expand Up @@ -268,6 +274,7 @@ def tune_relay(
module_equality: str = "structural",
num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
instruments: Optional[Sequence[PassInstrument]] = None,
) -> Database:
"""Tune a Relay program.
Expand Down Expand Up @@ -319,6 +326,8 @@ def tune_relay(
The number of CPU cores to use during tuning.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of disabled passes during tasks extraction
instruments : Optional[Sequence[PassInstrument]]
The list of pass instrument implementations.
Returns
-------
Expand All @@ -327,7 +336,12 @@ def tune_relay(
"""
tasks, task_weights = extracted_tasks_to_tune_contexts(
extracted_tasks=extract_tasks(
mod, target, params, module_equality=module_equality, disabled_pass=disabled_pass
mod,
target,
params,
module_equality=module_equality,
disabled_pass=disabled_pass,
instruments=instruments,
),
work_dir=work_dir,
space=space,
Expand Down Expand Up @@ -369,6 +383,7 @@ def compile_relay(
executor: Optional["relay.backend.Executor"] = None,
disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None,
runtime: Optional["relay.backend.Runtime"] = None,
instruments: Optional[Sequence[PassInstrument]] = None,
):
"""Compile a relay program with a MetaSchedule database.
Expand Down Expand Up @@ -396,6 +411,8 @@ def compile_relay(
The list of disabled passes
runtime : Optional[relay.backend.Runtime]
The runtime to use in relay.build. It is not supported by RelayVM.
instruments : Optional[Sequence[PassInstrument]]
The list of pass instrument implementations.
Returns
-------
Expand All @@ -416,6 +433,7 @@ def compile_relay(
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
instruments=instruments,
):
if backend == "graph":
return relay.build(
Expand Down

0 comments on commit 292d088

Please sign in to comment.