[MetaSchedule] Enable Task Filtering #11512

Merged 1 commit on May 31, 2022
8 changes: 6 additions & 2 deletions python/tvm/meta_schedule/relay_integration.py
@@ -15,14 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""MetaSchedule-Relay integration"""
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import numpy as np # type: ignore
from tvm import nd
from tvm._ffi import get_global_func
from tvm.ir import IRModule, transform
from tvm.runtime import NDArray
from tvm.target import Target
from tvm.te import Tensor

from .extracted_task import ExtractedTask
from .utils import autotvm_silencer
@@ -36,6 +37,7 @@ def extract_task_from_relay(
opt_level: int = 3,
pass_config: Optional[Dict[str, Any]] = None,
disabled_pass: Optional[List[str]] = None,
filter_func: Callable[[List[Tensor]], bool] = None,
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relay program.

@@ -53,6 +55,8 @@
The pass config of the compiler
disabled_pass : Optional[List[str]]
The list of disabled passes of the compiler
filter_func : Callable[[List[tvm.te.Tensor]], bool]
The function used to filter extracted tasks; a task is kept only if it returns True for the task's TE tensors

Returns
-------
@@ -90,4 +94,4 @@ def extract_task_from_relay(
config=pass_config,
disabled_pass=disabled_pass,
):
return list(extract_task_func(mod, target, relay_params))
return list(extract_task_func(mod, target, relay_params, filter_func))
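
For reference, a minimal usage sketch of the new parameter. Here mod and params are assumed to come from a Relay frontend import, and my_filter stands for any user-defined predicate over a task's TE tensors (the unit test added at the end of this diff shows a concrete one):

from tvm import meta_schedule as ms

extracted_tasks = ms.extract_task_from_relay(
    mod,                    # tvm.IRModule produced by a Relay frontend (assumed)
    target="llvm",
    params=params,          # parameter dict for the module (assumed)
    filter_func=my_filter,  # hypothetical predicate: List[te.Tensor] -> bool
)
for task in extracted_tasks:
    print(task.task_name)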
2 changes: 1 addition & 1 deletion python/tvm/te/__init__.py
@@ -39,7 +39,7 @@
from .tag import tag_scope
from .operation import placeholder, compute, scan, extern, var, size_var, const
from .operation import thread_axis, reduce_axis
from .operation import create_prim_func, create_prim_func_from_outputs
from .operation import create_prim_func

from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
from .autodiff import gradient
29 changes: 5 additions & 24 deletions python/tvm/te/operation.py
@@ -15,17 +15,18 @@
# specific language governing permissions and limitations
# under the License.
""" Operation class for computation declaration."""
import inspect

# pylint: disable=invalid-name
from numbers import Integral as _Integral
from typing import List, Union
import inspect
from typing import List

import tvm._ffi
import tvm.tir
import tvm.tir._ffi_api
from tvm._ffi.base import string_types
from tvm.ir import Array
from tvm.runtime import convert
import tvm.tir
import tvm.tir._ffi_api

from . import _ffi_api
from . import tag as _tag
@@ -528,23 +529,3 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
if not isinstance(ops, (list, tuple, Array)):
ops = [ops]
return _ffi_api.CreatePrimFunc(ops)


def create_prim_func_from_outputs(
outputs: Union[_tensor.Tensor, List[_tensor.Tensor]],
) -> tvm.tir.PrimFunc:
"""Create a TensorIR PrimFunc from output tensor(s) in TE

Parameters
----------
outputs : Union[Tensor, List[Tensor]]
The source expression.

Returns
-------
func : tir.PrimFunc
The created function.
"""
if not isinstance(outputs, (list, tuple, Array)):
outputs = [outputs]
return _ffi_api.CreatePrimFuncFromOutputs(outputs)
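
Since create_prim_func_from_outputs is removed, the remaining create_prim_func is called with inputs and outputs listed explicitly. A minimal sketch of that pattern (the matmul shapes are illustrative, not taken from this change):

from tvm import te

A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
k = te.reduce_axis((0, 128), name="k")
C = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

# Pass inputs and outputs together as the argument list; the removed helper
# used to derive the inputs from the outputs automatically.
func = te.create_prim_func([A, B, C])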
80 changes: 55 additions & 25 deletions src/relay/backend/task_extraction.cc
@@ -31,25 +31,58 @@ namespace tvm {
namespace relay {
namespace backend {

namespace metaschedule {

using meta_schedule::ExtractedTask;
bool DefaultTaskFilter(const Array<te::Tensor>& args) {
using namespace ::tvm::te;
std::vector<Tensor> stack;
std::unordered_set<const TensorNode*> visited;
for (const Tensor& v : args) {
for (const PrimExpr& e : v->shape) {
// Dynamic shape is not supported for now
if (!e->IsInstance<IntImmNode>()) {
return false;
}
}
if (!visited.count(v.get())) {
visited.insert(v.get());
stack.push_back(v);
}
}
while (!stack.empty()) {
Tensor tensor = stack.back();
stack.pop_back();
if (tensor->op->IsInstance<PlaceholderOpNode>()) {
// do nothing
} else if (tensor->op->IsInstance<ComputeOpNode>()) {
Array<Tensor> inputs = tensor->op->InputTensors();
for (const Tensor& v : inputs) {
if (!visited.count(v.get())) {
visited.insert(v.get());
stack.push_back(v);
}
}
} else {
return false;
}
}
return true;
}

Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
Map<String, runtime::NDArray> params) {
Array<meta_schedule::ExtractedTask> ExtractTask(
IRModule mod, Target target, Map<String, runtime::NDArray> params,
runtime::TypedPackedFunc<bool(const Array<te::Tensor>&)> filter_func) {
using meta_schedule::ExtractedTask;
if (filter_func == nullptr) {
filter_func = DefaultTaskFilter;
}
backend::BindParamsInModule(mod, params);

// is_vm=true for backward compatibility
Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
pass_seqs.push_back(transform::FuseOps());

transform::Sequential seq(pass_seqs);
auto opt_mod = seq(std::move(mod));
mod = transform::Sequential(pass_seqs)(std::move(mod));

std::vector<ExtractedTask> tasks;
std::unordered_map<tec::CCacheKey, ExtractedTask> cache;

PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache](const Expr& exp) {
PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, &filter_func](const Expr& exp) {
if (exp->IsInstance<FunctionNode>()) {
Function relay_func = Downcast<Function>(exp);
if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
@@ -61,17 +94,19 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
it->second->weight += 1;
return;
}
Array<te::Tensor> inputs_outputs;
Array<te::Tensor> inputs_outputs{nullptr};
std::string fused_name;
std::tie(inputs_outputs, fused_name) =
tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
auto prim_func = tir::CreatePrimFunc(inputs_outputs);
GlobalVar prim_fn_var(fused_name);
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, prim_func}});
ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1);
tasks.push_back(extracted_task);
cache.emplace(cache_key, extracted_task);
if (filter_func(inputs_outputs)) {
tir::PrimFunc prim_func = tir::CreatePrimFunc(inputs_outputs);
GlobalVar prim_fn_var(fused_name);
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, prim_func}});
ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1);
tasks.push_back(extracted_task);
cache.emplace(cache_key, extracted_task);
}
}
});
// Tasks are extracted via post order visit, return the reversed list.
@@ -83,12 +118,7 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
return tasks;
}

} // namespace metaschedule

TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask")
.set_body_typed([](IRModule mod, Target target, Map<String, runtime::NDArray> params) {
return metaschedule::ExtractTask(mod, target, params);
});
TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask").set_body_typed(ExtractTask);

} // namespace backend
} // namespace relay
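
The default filter above can also be expressed from the Python side; the sketch below mirrors its behavior (reject dynamic shapes and any operation that is neither a placeholder nor a plain compute) and can serve as a starting point for a custom filter_func. It is an approximation, not part of this change:

from typing import List

from tvm import te, tir


def default_like_filter(args: List[te.Tensor]) -> bool:
    visited = set()
    stack = []
    for t in args:
        # Dynamic shape is not supported for now.
        if not all(isinstance(dim, tir.IntImm) for dim in t.shape):
            return False
        if t.handle.value not in visited:
            visited.add(t.handle.value)
            stack.append(t)
    while stack:
        t = stack.pop()
        if isinstance(t.op, te.PlaceholderOp):
            pass  # leaf tensor, nothing to traverse
        elif isinstance(t.op, te.ComputeOp):
            for inp in t.op.input_tensors:
                if inp.handle.value not in visited:
                    visited.add(inp.handle.value)
                    stack.append(inp)
        else:
            # Extern, scan, and other op kinds are rejected, as in the C++ default.
            return False
    return True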
33 changes: 0 additions & 33 deletions src/te/operation/create_primfunc.cc
@@ -458,40 +458,7 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
return LayoutFreePlaceholdersNormalizer().Process(std::move(func));
}

PrimFunc CreatePrimFuncFromOutputs(const Array<te::Tensor>& outputs) {
std::vector<te::Tensor> stack;
std::unordered_set<const te::TensorNode*> visited;
for (const te::Tensor& output : outputs) {
if (!visited.count(output.get())) {
visited.insert(output.get());
stack.push_back(output);
}
}

Array<te::Tensor> arg_list;
while (!stack.empty()) {
te::Tensor tensor = stack.back();
stack.pop_back();
if (tensor->op->IsInstance<te::PlaceholderOpNode>()) {
arg_list.push_back(tensor);
} else if (tensor->op->IsInstance<te::ComputeOpNode>()) {
Array<te::Tensor> inputs = tensor->op->InputTensors();
for (const te::Tensor& input : inputs) {
if (!visited.count(input.get())) {
visited.insert(input.get());
stack.push_back(input);
}
}
}
}
for (const te::Tensor& output : outputs) {
arg_list.push_back(output);
}
return CreatePrimFunc(arg_list);
}

TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc);
TVM_REGISTER_GLOBAL("te.CreatePrimFuncFromOutputs").set_body_typed(CreatePrimFuncFromOutputs);

} // namespace tir
} // namespace tvm
3 changes: 0 additions & 3 deletions src/te/operation/create_primfunc.h
@@ -30,9 +30,6 @@ namespace tir {
/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list);

/*! \brief Create a schedulable TensorIR func from TE compute outputs. */
PrimFunc CreatePrimFuncFromOutputs(const Array<te::Tensor>& outputs);

} // namespace tir
} // namespace tvm

20 changes: 10 additions & 10 deletions src/tir/schedule/concrete_schedule.cc
@@ -199,16 +199,16 @@ Schedule ConcreteScheduleNode::Copy() {
 * \param level A ScheduleErrorRenderLevel enum, the level of error rendering
* \sa ScheduleErrorRenderLevel
*/
#define TVM_TIR_SCHEDULE_END(primitive, level) \
} \
catch (const ScheduleError& error) { \
if ((level) == ScheduleErrorRenderLevel::kDetail) { \
throw tvm::runtime::Error(error.RenderReport(primitive)); \
} else if ((level) == ScheduleErrorRenderLevel::kFast) { \
throw tvm::runtime::Error(error.FastErrorString()); \
} else if ((level) == ScheduleErrorRenderLevel::kNone) { \
throw tvm::runtime::Error("ScheduleError: (not rendered)"); \
} \
#define TVM_TIR_SCHEDULE_END(primitive, level) \
} \
catch (const ScheduleError& error) { \
if ((level) == ScheduleErrorRenderLevel::kDetail) { \
throw tvm::runtime::Error(error.RenderReport(primitive) + "\n" + runtime::Backtrace()); \
} else if ((level) == ScheduleErrorRenderLevel::kFast) { \
throw tvm::runtime::Error(error.FastErrorString()); \
} else if ((level) == ScheduleErrorRenderLevel::kNone) { \
throw tvm::runtime::Error("ScheduleError: (not rendered)"); \
} \
}

/******** Schedule: Schedule: Sampling ********/
63 changes: 63 additions & 0 deletions tests/python/unittest/test_meta_schedule_integration.py
@@ -196,6 +196,69 @@ def test_meta_schedule_integration_extract_from_bert_base():
assert expected_shape == shape, t.task_name


@requires_torch
def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
def filter_func(args) -> bool:
from tvm import te, tir

has_complex_op = False
visited = set()

def traverse(t):
nonlocal has_complex_op
assert t.handle is not None
if t.handle.value in visited:
return
if isinstance(t.op, te.PlaceholderOp):
pass
elif isinstance(t.op, te.ComputeOp):
has_complex_op = has_complex_op or any(
[isinstance(e, tir.Reduce) for e in t.op.body]
)
for x in t.op.input_tensors:
traverse(x)
visited.add(t.handle.value)

for t in args:
traverse(t)
return has_complex_op

mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
extracted_tasks = ms.extract_task_from_relay(
mod,
target="llvm",
params=params,
filter_func=filter_func,
)
expected_task_names = [
"fused_" + s
for s in [
"nn_max_pool2d",
"nn_adaptive_avg_pool2d",
"nn_dense_add",
"nn_conv2d_add",
"nn_conv2d_add_1",
"nn_conv2d_add_2",
"nn_conv2d_add_add_nn_relu",
"nn_conv2d_add_add_nn_relu_1",
"nn_conv2d_add_nn_relu",
"nn_conv2d_add_nn_relu_1",
"nn_conv2d_add_nn_relu_2",
"nn_conv2d_add_nn_relu_3",
"nn_conv2d_add_nn_relu_4",
"nn_conv2d_add_nn_relu_5",
"nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu",
"nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1",
"nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu",
"nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1",
]
]

assert len(extracted_tasks) == len(expected_task_names)
for t in extracted_tasks:
assert t.task_name in expected_task_names, t.task_name


@requires_torch
def test_meta_schedule_integration_apply_history_best():
mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])