diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py
index 47f76830ab88..b55633817413 100644
--- a/python/tvm/meta_schedule/relay_integration.py
+++ b/python/tvm/meta_schedule/relay_integration.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """MetaSchedule-Relay integration"""
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import numpy as np  # type: ignore
 from tvm import nd
@@ -23,6 +23,7 @@
 from tvm.ir import IRModule, transform
 from tvm.runtime import NDArray
 from tvm.target import Target
+from tvm.te import Tensor

 from .extracted_task import ExtractedTask
 from .utils import autotvm_silencer
@@ -36,6 +37,7 @@ def extract_task_from_relay(
     opt_level: int = 3,
     pass_config: Optional[Dict[str, Any]] = None,
     disabled_pass: Optional[List[str]] = None,
+    filter_func: Callable[[List[Tensor]], bool] = None,
 ) -> List[ExtractedTask]:
     """Extract tuning tasks from a relay program.

@@ -53,6 +55,8 @@ def extract_task_from_relay(
         The pass config of the compiler
     disabled_pass : Optional[List[str]]
         The list of disabled passes of the compiler
+    filter_func : Callable[[List[tvm.te.Tensor]], bool]
+        The filter function to filter out the extracted tasks

     Returns
     -------
@@ -90,4 +94,4 @@ def extract_task_from_relay(
         config=pass_config,
         disabled_pass=disabled_pass,
     ):
-        return list(extract_task_func(mod, target, relay_params))
+        return list(extract_task_func(mod, target, relay_params, filter_func))
diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py
index 4c4e223f2d72..1777d8707c7c 100644
--- a/python/tvm/te/__init__.py
+++ b/python/tvm/te/__init__.py
@@ -39,7 +39,7 @@
 from .tag import tag_scope
 from .operation import placeholder, compute, scan, extern, var, size_var, const
 from .operation import thread_axis, reduce_axis
-from .operation import create_prim_func, create_prim_func_from_outputs
+from .operation import create_prim_func
 from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
 from .autodiff import gradient

diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 90d7cb5d75db..df5dd2c4ffd8 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -15,17 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 """ Operation class for computation declaration."""
+import inspect
+
 # pylint: disable=invalid-name
 from numbers import Integral as _Integral
-from typing import List, Union
-import inspect
+from typing import List

 import tvm._ffi
+import tvm.tir
+import tvm.tir._ffi_api
 from tvm._ffi.base import string_types
 from tvm.ir import Array
 from tvm.runtime import convert
-import tvm.tir
-import tvm.tir._ffi_api

 from . import _ffi_api
 from . import tag as _tag
@@ -528,23 +529,3 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
     if not isinstance(ops, (list, tuple, Array)):
         ops = [ops]
     return _ffi_api.CreatePrimFunc(ops)
-
-
-def create_prim_func_from_outputs(
-    outputs: Union[_tensor.Tensor, List[_tensor.Tensor]],
-) -> tvm.tir.PrimFunc:
-    """Create a TensorIR PrimFunc from output tensor(s) in TE
-
-    Parameters
-    ----------
-    outputs : Union[Tensor, List[Tensor]]
-        The source expression.
-
-    Returns
-    -------
-    func : tir.PrimFunc
-        The created function.
-    """
-    if not isinstance(outputs, (list, tuple, Array)):
-        outputs = [outputs]
-    return _ffi_api.CreatePrimFuncFromOutputs(outputs)
diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc
index 0895fd42a307..6ec881111d77 100644
--- a/src/relay/backend/task_extraction.cc
+++ b/src/relay/backend/task_extraction.cc
@@ -31,25 +31,58 @@
 namespace tvm {
 namespace relay {
 namespace backend {

-namespace metaschedule {
-
-using meta_schedule::ExtractedTask;
+bool DefaultTaskFilter(const Array<te::Tensor>& args) {
+  using namespace ::tvm::te;
+  std::vector<Tensor> stack;
+  std::unordered_set<const TensorNode*> visited;
+  for (const Tensor& v : args) {
+    for (const PrimExpr& e : v->shape) {
+      // Dynamic shape is not supported for now
+      if (!e->IsInstance<IntImmNode>()) {
+        return false;
+      }
+    }
+    if (!visited.count(v.get())) {
+      visited.insert(v.get());
+      stack.push_back(v);
+    }
+  }
+  while (!stack.empty()) {
+    Tensor tensor = stack.back();
+    stack.pop_back();
+    if (tensor->op->IsInstance<PlaceholderOpNode>()) {
+      // do nothing
+    } else if (tensor->op->IsInstance<ComputeOpNode>()) {
+      Array<Tensor> inputs = tensor->op->InputTensors();
+      for (const Tensor& v : inputs) {
+        if (!visited.count(v.get())) {
+          visited.insert(v.get());
+          stack.push_back(v);
+        }
+      }
+    } else {
+      return false;
+    }
+  }
+  return true;
+}

-Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
-                                 Map<String, runtime::NDArray> params) {
+Array<meta_schedule::ExtractedTask> ExtractTask(
+    IRModule mod, Target target, Map<String, runtime::NDArray> params,
+    runtime::TypedPackedFunc<bool(const Array<te::Tensor>&)> filter_func) {
+  using meta_schedule::ExtractedTask;
+  if (filter_func == nullptr) {
+    filter_func = DefaultTaskFilter;
+  }
   backend::BindParamsInModule(mod, params);
-  // is_vm=true for backward compatibility
   Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
   pass_seqs.push_back(transform::FuseOps());
-
-  transform::Sequential seq(pass_seqs);
-  auto opt_mod = seq(std::move(mod));
+  mod = transform::Sequential(pass_seqs)(std::move(mod));

   std::vector<ExtractedTask> tasks;
   std::unordered_map<tec::CCacheKey, ExtractedTask> cache;
-
-  PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache](const Expr& exp) {
+  PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, &filter_func](const Expr& exp) {
     if (exp->IsInstance<FunctionNode>()) {
       Function relay_func = Downcast<Function>(exp);
       if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
@@ -61,17 +94,19 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
         it->second->weight += 1;
         return;
       }
-      Array<te::Tensor> inputs_outputs;
+      Array<te::Tensor> inputs_outputs{nullptr};
       std::string fused_name;
       std::tie(inputs_outputs, fused_name) =
           tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
-      auto prim_func = tir::CreatePrimFunc(inputs_outputs);
-      GlobalVar prim_fn_var(fused_name);
-      IRModule relay_mod({{prim_fn_var, relay_func}});
-      IRModule tir_mod({{prim_fn_var, prim_func}});
-      ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1);
-      tasks.push_back(extracted_task);
-      cache.emplace(cache_key, extracted_task);
+      if (filter_func(inputs_outputs)) {
+        tir::PrimFunc prim_func = tir::CreatePrimFunc(inputs_outputs);
+        GlobalVar prim_fn_var(fused_name);
+        IRModule relay_mod({{prim_fn_var, relay_func}});
+        IRModule tir_mod({{prim_fn_var, prim_func}});
+        ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1);
+        tasks.push_back(extracted_task);
+        cache.emplace(cache_key, extracted_task);
+      }
     }
   });
   // Tasks are extracted via post order visit, return the reversed list.
@@ -83,12 +118,7 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
   return tasks;
 }

-}  // namespace metaschedule
-
-TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask")
-    .set_body_typed([](IRModule mod, Target target, Map<String, runtime::NDArray> params) {
-      return metaschedule::ExtractTask(mod, target, params);
-    });
+TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask").set_body_typed(ExtractTask);

 }  // namespace backend
 }  // namespace relay
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 7e7dae855802..03ad551c6839 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -458,40 +458,7 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
   return LayoutFreePlaceholdersNormalizer().Process(std::move(func));
 }

-PrimFunc CreatePrimFuncFromOutputs(const Array<te::Tensor>& outputs) {
-  std::vector<te::Tensor> stack;
-  std::unordered_set<const te::TensorNode*> visited;
-  for (const te::Tensor& output : outputs) {
-    if (!visited.count(output.get())) {
-      visited.insert(output.get());
-      stack.push_back(output);
-    }
-  }
-
-  Array<te::Tensor> arg_list;
-  while (!stack.empty()) {
-    te::Tensor tensor = stack.back();
-    stack.pop_back();
-    if (tensor->op->IsInstance<te::PlaceholderOpNode>()) {
-      arg_list.push_back(tensor);
-    } else if (tensor->op->IsInstance<te::ComputeOpNode>()) {
-      Array<te::Tensor> inputs = tensor->op->InputTensors();
-      for (const te::Tensor& input : inputs) {
-        if (!visited.count(input.get())) {
-          visited.insert(input.get());
-          stack.push_back(input);
-        }
-      }
-    }
-  }
-  for (const te::Tensor& output : outputs) {
-    arg_list.push_back(output);
-  }
-  return CreatePrimFunc(arg_list);
-}
-
 TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc);
-TVM_REGISTER_GLOBAL("te.CreatePrimFuncFromOutputs").set_body_typed(CreatePrimFuncFromOutputs);

 }  // namespace tir
 }  // namespace tvm
diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h
index d911e5ebcdb7..c3cddd83f57a 100644
--- a/src/te/operation/create_primfunc.h
+++ b/src/te/operation/create_primfunc.h
@@ -30,9 +30,6 @@ namespace tir {
 /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
 PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list);

-/*! \brief Create a schedulable TensorIR func from TE compute outputs. */
-PrimFunc CreatePrimFuncFromOutputs(const Array<te::Tensor>& outputs);
-
 }  // namespace tir
 }  // namespace tvm

diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 8066d85a8e7d..2289899c329b 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -199,16 +199,16 @@ Schedule ConcreteScheduleNode::Copy() {
  * \param level An ScheduleErrorRenderLevel enum, level of error rendering
  * \sa ScheduleErrorRenderLevel
  */
-#define TVM_TIR_SCHEDULE_END(primitive, level)                    \
-  }                                                               \
-  catch (const ScheduleError& error) {                            \
-    if ((level) == ScheduleErrorRenderLevel::kDetail) {           \
-      throw tvm::runtime::Error(error.RenderReport(primitive));   \
-    } else if ((level) == ScheduleErrorRenderLevel::kFast) {      \
-      throw tvm::runtime::Error(error.FastErrorString());         \
-    } else if ((level) == ScheduleErrorRenderLevel::kNone) {      \
-      throw tvm::runtime::Error("ScheduleError: (not rendered)"); \
-    }                                                             \
+#define TVM_TIR_SCHEDULE_END(primitive, level)                                                \
+  }                                                                                           \
+  catch (const ScheduleError& error) {                                                        \
+    if ((level) == ScheduleErrorRenderLevel::kDetail) {                                       \
+      throw tvm::runtime::Error(error.RenderReport(primitive) + "\n" + runtime::Backtrace()); \
+    } else if ((level) == ScheduleErrorRenderLevel::kFast) {                                  \
+      throw tvm::runtime::Error(error.FastErrorString());                                     \
+    } else if ((level) == ScheduleErrorRenderLevel::kNone) {                                  \
+      throw tvm::runtime::Error("ScheduleError: (not rendered)");                             \
+    }                                                                                         \
   }

 /******** Schedule: Schedule: Sampling ********/
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index cd6e1b4c405a..a423bdb48afd 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -196,6 +196,69 @@ def test_meta_schedule_integration_extract_from_bert_base():
         assert expected_shape == shape, t.task_name


+@requires_torch
+def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
+    def filter_func(args) -> bool:
+        from tvm import te, tir
+
+        has_complex_op = False
+        visited = set()
+
+        def traverse(t):
+            nonlocal has_complex_op
+            assert t.handle is not None
+            if t.handle.value in visited:
+                return
+            if isinstance(t.op, te.PlaceholderOp):
+                pass
+            elif isinstance(t.op, te.ComputeOp):
+                has_complex_op = has_complex_op or any(
+                    [isinstance(e, tir.Reduce) for e in t.op.body]
+                )
+                for x in t.op.input_tensors:
+                    traverse(x)
+            visited.add(t.handle.value)
+
+        for t in args:
+            traverse(t)
+        return has_complex_op
+
+    mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
+    extracted_tasks = ms.extract_task_from_relay(
+        mod,
+        target="llvm",
+        params=params,
+        filter_func=filter_func,
+    )
+    expected_task_names = [
+        "fused_" + s
+        for s in [
+            "nn_max_pool2d",
+            "nn_adaptive_avg_pool2d",
+            "nn_dense_add",
+            "nn_conv2d_add",
+            "nn_conv2d_add_1",
+            "nn_conv2d_add_2",
+            "nn_conv2d_add_add_nn_relu",
+            "nn_conv2d_add_add_nn_relu_1",
+            "nn_conv2d_add_nn_relu",
+            "nn_conv2d_add_nn_relu_1",
+            "nn_conv2d_add_nn_relu_2",
+            "nn_conv2d_add_nn_relu_3",
+            "nn_conv2d_add_nn_relu_4",
+            "nn_conv2d_add_nn_relu_5",
+            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu",
+            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1",
+            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu",
+            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1",
+        ]
+    ]
+
+    assert len(extracted_tasks) == len(expected_task_names)
+    for t in extracted_tasks:
+        assert t.task_name in expected_task_names, t.task_name
+
+
 @requires_torch
 def test_meta_schedule_integration_apply_history_best():
     mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
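
Usage sketch (illustrative, not taken from the patch): the new filter_func hook receives the fused TE tensors of each candidate task and returns True to keep it; leaving it as None falls back to the C++ DefaultTaskFilter added above, which rejects tasks with dynamic shapes or with ops other than placeholders and computes. Below is a minimal Python filter that mirrors the static-shape check of DefaultTaskFilter; the tiny conv2d workload is an assumption chosen purely for demonstration.

import numpy as np
import tvm
from tvm import meta_schedule as ms
from tvm import relay, tir


def static_shape_only(args) -> bool:
    # Keep a task only if every tensor handed to the filter has a fully
    # static shape, mirroring the intent of DefaultTaskFilter above.
    return all(isinstance(dim, tir.IntImm) for t in args for dim in t.shape)


# Tiny illustrative Relay workload; any (mod, params) pair works the same way.
data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
out = relay.nn.relu(relay.nn.conv2d(data, weight, padding=(1, 1)))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
params = {"weight": tvm.nd.array(np.random.rand(8, 3, 3, 3).astype("float32"))}

tasks = ms.extract_task_from_relay(
    mod, target="llvm", params=params, filter_func=static_shape_only
)
print([task.task_name for task in tasks])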