Skip to content

Commit

Permalink
New relay backend for meta schedule task extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 11, 2022
2 parents 076fa33 + ce8c563 commit 501fac6
Show file tree
Hide file tree
Showing 13 changed files with 254 additions and 235 deletions.
38 changes: 3 additions & 35 deletions include/tvm/meta_schedule/integration.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class MetaScheduleContextNode : public runtime::Object {
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
* 4) IRModule for unified dispatch
*/
virtual Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
virtual IRModule Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) = 0;

static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext";
Expand Down Expand Up @@ -129,7 +129,7 @@ class MetaScheduleContext : public runtime::ObjectRef {
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
* 4) IRModule for unified dispatch
*/
static Optional<ObjectRef> QueryInsideWithScope(runtime::String task_name, IRModule mod,
static IRModule QueryInsideWithScope(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched);

Expand All @@ -145,38 +145,6 @@ class MetaScheduleContext : public runtime::ObjectRef {
void ExitWithScope();
};

/**************** TaskExtraction ****************/

/*!
* \brief An integration context for task extraction
*/
class TaskExtractionNode : public MetaScheduleContextNode {
public:
/*! \brief The extracted tasks */
Array<ExtractedTask> tasks{nullptr};

void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); }

// Inherited from base class
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) final;

static constexpr const char* _type_key = "meta_schedule.TaskExtraction";
TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, MetaScheduleContextNode);
};

/*!
* \brief Managed reference to TaskExtractionNode
* \sa TaskExtractionNode
*/
class TaskExtraction : public MetaScheduleContext {
public:
/*! \brief The path to a cache file storing extracted tasks */
TaskExtraction();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, MetaScheduleContext,
TaskExtractionNode);
};

/**************** ApplyHistoryBest ****************/

/*!
Expand All @@ -193,7 +161,7 @@ class ApplyHistoryBestNode : public MetaScheduleContextNode {
}

// Inherited from base class
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
IRModule Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) final;

static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
Expand Down
65 changes: 21 additions & 44 deletions python/tvm/meta_schedule/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Union

from tvm._ffi import register_object
import numpy as np
import tvm.runtime.ndarray as nd

from tvm._ffi import register_object, get_global_func
from tvm.ir import IRModule, transform
from tvm.relay import Any
from tvm.relay import Any, const
from tvm.relay import Function as RelayFunc
from tvm.relay import vm
from tvm.runtime import NDArray, Object
Expand Down Expand Up @@ -176,17 +179,6 @@ def __exit__(self, ptype, value, trace) -> None:
_ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member


@register_object("meta_schedule.TaskExtraction")
class TaskExtraction(MetaScheduleContext):
"""An integration context for task extraction"""

tasks: List[ExtractedTask]
"""The extracted tasks"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(_ffi_api.TaskExtraction) # type: ignore # pylint: disable=no-member


@register_object("meta_schedule.ApplyHistoryBest")
class ApplyHistoryBest(MetaScheduleContext):
"""An integration context that allows application of historically best record from database"""
Expand Down Expand Up @@ -230,45 +222,30 @@ def extract_task_from_relay(
The tasks extracted from this network
"""

@contextmanager
def _autotvm_silencer():
from tvm import autotvm # pylint: disable=import-outside-toplevel

silent = autotvm.GLOBAL_SCOPE.silent
autotvm.GLOBAL_SCOPE.silent = True
try:
yield
finally:
autotvm.GLOBAL_SCOPE.silent = silent
extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask")
assert extract_task_func

def _thread_run(func: Callable[[], None]) -> None:
import threading # pylint: disable=import-outside-toplevel
target = Target(target) if isinstance(target, str) else target

thread = threading.Thread(target=func)
thread.start()
thread.join()
relay_params = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = nd.array(param)
relay_params[name] = const(param)

if disabled_pass is None:
disabled_pass = []
if pass_config is None:
pass_config = {"relay.backend.use_meta_schedule": True}

env = TaskExtraction()
if isinstance(mod, RelayFunc):
mod = IRModule.from_expr(mod)
if not isinstance(target, Target):
target = Target(target)

def _func():
with env, _autotvm_silencer(), transform.PassContext(
config=pass_config,
disabled_pass=disabled_pass,
opt_level=opt_level,
):
compiler = vm.VMCompiler()
if params:
compiler.set_params(params)
compiler.lower(mod, target)

_thread_run(_func)
return env.tasks
with target, transform.PassContext(
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
):
tasks = extract_task_func(mod, target, relay_params)
# Tasks are extracted via post order visit, return the reversed list.
return list(reversed(tasks))
1 change: 1 addition & 0 deletions python/tvm/topi/x86/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def batch_matmul_vnni_compute(cfg, x, y):
axis=ak,
),
tag="batch_matmul_vnni",
attrs={"schedule_rule": "batch_matmul_vnni"},
)

_, a_y, _ = z.op.axis
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/x86/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None):
axis=ak,
),
tag="dense_vnni",
attrs={"schedule_rule": "dense_vnni"},
)

if bias is not None:
Expand Down
62 changes: 27 additions & 35 deletions src/meta_schedule/integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ namespace tvm {
namespace meta_schedule {

/**************** Utility functions ****************/

template <class FunctionType>
Optional<FunctionType> GetOnlyOneFunction(const IRModule& mod) {
template <class FunctionType, class RetType, class Callback>
Optional<RetType> GetOnlyOneFunctionCommon(const IRModule& mod, Callback on_found) {
if (mod->functions.size() != 1) {
return NullOpt;
}
Expand All @@ -37,12 +36,23 @@ Optional<FunctionType> GetOnlyOneFunction(const IRModule& mod) {
if (!func->IsInstance<typename FunctionType::ContainerType>()) {
return NullOpt;
} else {
return Downcast<FunctionType>(func);
return on_found(kv);
}
}
return NullOpt;
}

template <class FunctionType>
Optional<GlobalVar> GetOnlyOneFunctionKey(const IRModule& mod) {
return GetOnlyOneFunctionCommon<FunctionType, GlobalVar>(mod, [](auto kv) { return kv.first; });
}

template <class FunctionType>
Optional<FunctionType> GetOnlyOneFunction(const IRModule& mod) {
return GetOnlyOneFunctionCommon<FunctionType, FunctionType>(
mod, [](auto kv) { return Downcast<FunctionType>(kv.second); });
}

template <class FunctionType>
bool HasOnlyOneFunction(const IRModule& mod) {
return GetOnlyOneFunction<FunctionType>(mod).defined();
Expand Down Expand Up @@ -86,31 +96,13 @@ void MetaScheduleContext::ExitWithScope() {
ctx = NullOpt;
}

Optional<ObjectRef> MetaScheduleContext::QueryInsideWithScope(
runtime::String task_name, IRModule mod, Target target, Optional<Array<IRModule>> dispatched) {
IRModule MetaScheduleContext::QueryInsideWithScope(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched) {
if (Optional<MetaScheduleContext> ctx = MetaScheduleContext::Current()) {
return ctx.value()->Query(task_name, mod, target, dispatched);
}
return NullOpt;
}

/**************** TaskExtraction ****************/

TaskExtraction::TaskExtraction() {
ObjectPtr<TaskExtractionNode> n = make_object<TaskExtractionNode>();
n->tasks = Array<ExtractedTask>();
data_ = n;
}

Optional<ObjectRef> TaskExtractionNode::Query(runtime::String task_name, IRModule mod,
Target target, Optional<Array<IRModule>> dispatched) {
ICHECK(dispatched.defined());
ICHECK_EQ(dispatched.value().size(), 1);
IRModule prim_mod = dispatched.value()[0];
ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
tasks.push_back(ExtractedTask(task_name, mod, target, {prim_mod}));
return NullOpt;
return IRModule{nullptr};
}

/**************** ApplyHistoryBest ****************/
Expand All @@ -121,14 +113,18 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) {
data_ = n;
}

Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched) {
IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) {
ICHECK(dispatched.defined());
ICHECK_EQ(dispatched.value().size(), 1);
ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
IRModule prim_mod = dispatched.value()[0];
ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
// TODO(masahi): parse_mod below replaces the orginal function key with "main".
// This is necessary because some scheduling primitives requires the PrimFunc key be "main".
// If we can remove this restriction, there would no need for GetOnlyOneFunction* calls below
// and we can directly return sch->mod().
auto gv = GetOnlyOneFunctionKey<tir::PrimFunc>(prim_mod).value();
// Unify func name to make sure it can be found in database
const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod");
ICHECK(parse_mod_func) << "Parse mod function not defined!";
Expand All @@ -141,11 +137,11 @@ Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRMod
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
records[0]->trace->ApplyToSchedule(sch, false);
tir::PrimFunc func = GetOnlyOneFunction<tir::PrimFunc>(sch->mod()).value();
return func;
return IRModule({{gv, func}});
}
}
LOG(WARNING) << "Cannot find workload: " << task_name << "\n" << tir::AsTVMScript(prim_mod);
return NullOpt;
return IRModule{nullptr};
}

/**************** FFI ****************/
Expand All @@ -158,7 +154,6 @@ class MetaScheduleContextInternal {

TVM_REGISTER_NODE_TYPE(ExtractedTaskNode);
TVM_REGISTER_OBJECT_TYPE(MetaScheduleContextNode);
TVM_REGISTER_NODE_TYPE(TaskExtractionNode);
TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode);

TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
Expand All @@ -176,9 +171,6 @@ TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQueryInsideWithScope")
.set_body_typed(MetaScheduleContext::QueryInsideWithScope);
TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery")
.set_body_method<MetaScheduleContext>(&MetaScheduleContextNode::Query);
TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction {
return TaskExtraction();
});
TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest")
.set_body_typed([](Database database) -> ApplyHistoryBest {
return ApplyHistoryBest(database);
Expand Down
9 changes: 1 addition & 8 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,14 +333,7 @@ class RelayBuildModule : public runtime::ModuleNode {
IRModule OptimizeImpl(IRModule relay_module) {
ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler.";

if (!params_.empty()) {
ICHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function";
GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
auto new_main = BindParamsByName(main_func, params_);
IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite();
relay_module_ptr->Update(main_glb_var, new_main);
}
backend::BindParamsInModule(relay_module, params_);

Array<Pass> pass_seqs = GetPassPrefix(
/*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false);
Expand Down
Loading

0 comments on commit 501fac6

Please sign in to comment.