diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 876dba106c38..8910dc17b202 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -71,15 +71,8 @@ def _normalize_params( mod: IRModule, target: Union[Target, str], params: Optional[Dict[str, NDArray]], - pass_config: Mapping[str, Any], executor: Optional["relay.backend.Executor"], -) -> Tuple[ - IRModule, - Target, - Dict[str, NDArray], - Dict[str, Any], - Optional["relay.backend.Executor"], -]: +) -> Tuple[IRModule, Target, Dict[str, NDArray], Optional["relay.backend.Executor"],]: from tvm import relay # pylint: disable=import-outside-toplevel if isinstance(mod, relay.Function): @@ -102,8 +95,7 @@ def _normalize_params( else: executor = mod.get_attr("executor") - pass_config = dict(pass_config) - return mod, target, relay_params, pass_config, executor + return mod, target, relay_params, executor def extract_tasks( @@ -111,16 +103,8 @@ def extract_tasks( target: Union[Target, str], params: Optional[Dict[str, NDArray]], *, - opt_level: int = 3, - pass_config: Mapping[str, Any] = MappingProxyType( - { - "relay.backend.use_meta_schedule": True, - "relay.backend.tir_converter": "default", - } - ), executor: Optional["relay.backend.Executor"] = None, module_equality: str = "structural", - disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ) -> List[ExtractedTask]: """Extract tuning tasks from a relay program. @@ -132,10 +116,6 @@ def extract_tasks( The compilation target params : Optional[Dict[str, tvm.runtime.NDArray]] The associated parameters of the program - opt_level : int - The optimization level of the compilation - pass_config : Mapping[str, Any] - The pass configuration executor : Optional[relay.backend.Executor] The executor to use module_equality : Optional[str] @@ -148,8 +128,6 @@ def extract_tasks( given module. The "ignore-ndarray" varint is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. - disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] - The list of disabled passes Returns ------- @@ -160,21 +138,25 @@ def extract_tasks( from tvm import autotvm # pylint: enable=import-outside-toplevel - mod, target, params, pass_config, _ = _normalize_params( - mod, target, params, pass_config, executor - ) + mod, target, params, _ = _normalize_params(mod, target, params, executor) if target.kind.name != "cuda" and isinstance( autotvm.DispatchContext.current, autotvm.FallbackContext ): tophub_context = autotvm.tophub.context(target) else: tophub_context = autotvm.utils.EmptyContext() + pass_ctx = transform.PassContext.current() + pass_config = dict(pass_ctx.config) + pass_config.setdefault("relay.backend.use_meta_schedule", True) + pass_config.setdefault("relay.backend.tir_converter", "default") with Profiler.timeit("TaskExtraction"): with target, _autotvm_silencer(), tophub_context: with transform.PassContext( - opt_level=opt_level, + opt_level=pass_ctx.opt_level, + required_pass=pass_ctx.required_pass, + disabled_pass=pass_ctx.disabled_pass, + instruments=pass_ctx.instruments, config=pass_config, - disabled_pass=disabled_pass, ): return list(_extract_task(mod, target, params, module_equality)) @@ -254,7 +236,6 @@ def tune_relay( seed: Optional[int] = None, module_equality: str = "structural", num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical", - disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ) -> Database: """Tune a Relay program. @@ -304,8 +285,6 @@ def tune_relay( For the definition of the anchor block, see tir/analysis/analysis.py. num_tuning_cores : Union[Literal["physical", "logical"], int] The number of CPU cores to use during tuning. - disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] - The list of disabled passes during tasks extraction Returns ------- @@ -313,9 +292,7 @@ def tune_relay( The database that contains the tuning records """ tasks, task_weights = extracted_tasks_to_tune_contexts( - extracted_tasks=extract_tasks( - mod, target, params, module_equality=module_equality, disabled_pass=disabled_pass - ), + extracted_tasks=extract_tasks(mod, target, params, module_equality=module_equality), work_dir=work_dir, space=space, strategy=strategy, @@ -346,15 +323,7 @@ def compile_relay( params: Optional[Dict[str, NDArray]], *, backend: Literal["graph", "vm"] = "graph", - opt_level: int = 3, - pass_config: Mapping[str, Any] = MappingProxyType( - { - "relay.backend.use_meta_schedule": True, - "relay.backend.tir_converter": "default", - } - ), executor: Optional["relay.backend.Executor"] = None, - disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, ): """Compile a relay program with a MetaSchedule database. @@ -372,14 +341,8 @@ def compile_relay( The backend to use. Builtin backends: - "graph" - "vm" - opt_level : int - The optimization level of the compilation - pass_config : Mapping[str, Any] - The pass configuration executor : Optional[relay.backend.Executor] The executor to use in relay.build. It is not supported by RelayVM. - disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] - The list of disabled passes Returns ------- @@ -390,16 +353,20 @@ def compile_relay( from tvm import relay # pylint: enable=import-outside-toplevel - mod, target, params, pass_config, executor = _normalize_params( - mod, target, params, pass_config, executor - ) + mod, target, params, executor = _normalize_params(mod, target, params, executor) + pass_ctx = transform.PassContext.current() + pass_config = dict(pass_ctx.config) pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", True) + pass_config.setdefault("relay.backend.use_meta_schedule", True) + pass_config.setdefault("relay.backend.tir_converter", "default") with Profiler.timeit("PostTuningCompilation"): with target, _autotvm_silencer(), database: with transform.PassContext( - opt_level=opt_level, + opt_level=pass_ctx.opt_level, + required_pass=pass_ctx.required_pass, + disabled_pass=pass_ctx.disabled_pass, + instruments=pass_ctx.instruments, config=pass_config, - disabled_pass=disabled_pass, ): if backend == "graph": return relay.build(mod, target=target, params=params, executor=executor) diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 1e01cb28a749..6ec1b4dd81c2 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -347,7 +347,7 @@ def schedule_conv2d_for_tune(sch: Schedule): else None ) - with tempfile.TemporaryDirectory() as work_dir: + with tempfile.TemporaryDirectory() as work_dir, pass_context: database = ms.relay_integration.tune_relay( mod=mod, target=TARGET_HEXAGON, @@ -382,15 +382,11 @@ def schedule_conv2d_for_tune(sch: Schedule): module_equality="ignore-ndarray", ) - # Add default options so that it still uses the base config. - pass_config["relay.backend.use_meta_schedule"] = True - pass_config["relay.backend.tir_converter"] = "default" return ms.relay_integration.compile_relay( database=database, mod=mod, target=TARGET_HEXAGON, params=params, - pass_config=pass_config, )