diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index fbdf68d09767..41d3f9d12ebc 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -17,16 +17,18 @@ """MetaSchedule-Relay integration""" from contextlib import contextmanager from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, Set +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple, Union # isort: off from typing_extensions import Literal # isort: on import numpy as np # type: ignore + from tvm import nd from tvm._ffi import get_global_func from tvm.ir import IRModule, transform +from tvm.ir.instrument import PassInstrument from tvm.runtime import NDArray from tvm.target import Target @@ -127,6 +129,7 @@ def extract_tasks( runtime: Optional["relay.backend.Runtime"] = None, module_equality: str = "structural", disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, + instruments: Optional[Sequence[PassInstrument]] = None, ) -> List[ExtractedTask]: """Extract tuning tasks from a relay program. @@ -158,6 +161,8 @@ def extract_tasks( For the definition of the anchor block, see tir/analysis/analysis.py. disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of disabled passes + instruments : Optional[Sequence[PassInstrument]] + The list of pass instrument implementations. Returns ------- @@ -188,6 +193,7 @@ def extract_tasks( opt_level=opt_level, config=pass_config, disabled_pass=disabled_pass, + instruments=instruments, ): return list(_extract_task(mod, target, params, module_equality)) @@ -268,6 +274,7 @@ def tune_relay( module_equality: str = "structural", num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical", disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, + instruments: Optional[Sequence[PassInstrument]] = None, ) -> Database: """Tune a Relay program. @@ -319,6 +326,8 @@ def tune_relay( The number of CPU cores to use during tuning. disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of disabled passes during tasks extraction + instruments : Optional[Sequence[PassInstrument]] + The list of pass instrument implementations. Returns ------- @@ -327,7 +336,12 @@ def tune_relay( """ tasks, task_weights = extracted_tasks_to_tune_contexts( extracted_tasks=extract_tasks( - mod, target, params, module_equality=module_equality, disabled_pass=disabled_pass + mod, + target, + params, + module_equality=module_equality, + disabled_pass=disabled_pass, + instruments=instruments, ), work_dir=work_dir, space=space, @@ -369,6 +383,7 @@ def compile_relay( executor: Optional["relay.backend.Executor"] = None, disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, runtime: Optional["relay.backend.Runtime"] = None, + instruments: Optional[Sequence[PassInstrument]] = None, ): """Compile a relay program with a MetaSchedule database. @@ -396,6 +411,8 @@ def compile_relay( The list of disabled passes runtime : Optional[relay.backend.Runtime] The runtime to use in relay.build. It is not supported by RelayVM. + instruments : Optional[Sequence[PassInstrument]] + The list of pass instrument implementations. Returns ------- @@ -416,6 +433,7 @@ def compile_relay( opt_level=opt_level, config=pass_config, disabled_pass=disabled_pass, + instruments=instruments, ): if backend == "graph": return relay.build(