diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index e5d98624f7107..2477aa11536cd 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -23,7 +23,7 @@ from tvm._ffi import register_object, get_global_func from tvm.ir import IRModule, transform -from tvm.relay import Any +from tvm.relay import Any, const from tvm.relay import Function as RelayFunc from tvm.relay import vm from tvm.runtime import NDArray, Object @@ -238,9 +238,11 @@ def extract_task_from_relay( target = Target(target) if isinstance(target, str) else target + relay_params = {} for name, param in params.items(): if isinstance(param, np.ndarray): - params[name] = nd.array(param) + param = nd.array(param) + relay_params[name] = const(param) if disabled_pass is None: disabled_pass = [] @@ -250,11 +252,10 @@ def extract_task_from_relay( if not isinstance(target, Target): target = Target(target) - with transform.PassContext( + with target, transform.PassContext( opt_level=opt_level, config=pass_config, disabled_pass=disabled_pass, ): - with target: - tasks = extract_task_func(mod, target, params) - return tasks + tasks = extract_task_func(mod, target, relay_params) + return tasks diff --git a/src/relay/backend/metaschedule_task_extraction.cc b/src/relay/backend/metaschedule_task_extraction.cc index c0c7a525f3d3c..1492a824bec55 100644 --- a/src/relay/backend/metaschedule_task_extraction.cc +++ b/src/relay/backend/metaschedule_task_extraction.cc @@ -35,15 +35,7 @@ namespace metaschedule { using meta_schedule::ExtractedTask; Array ExtractTask(IRModule mod, Target target, Map params) { - // backend::BindParamsInModule(mod, params); - if (params.size()) { - std::unordered_map params_; - BaseFunc base_func = mod->Lookup("main"); - ICHECK(base_func->IsInstance()); - auto f = relay::backend::BindParamsByName(Downcast(base_func), params_); - auto gvar = mod->GetGlobalVar("main"); - mod->Add(gvar, f); - } + backend::BindParamsInModule(mod, params); Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); pass_seqs.push_back(transform::FuseOps());