Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metaschedule] New relay backend for meta schedule task extraction #10578

Merged
merged 19 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 13 additions & 49 deletions include/tvm/meta_schedule/integration.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,12 @@ class MetaScheduleContextNode : public runtime::Object {
* \param target Target info
* \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to.
* NullOpt means the dispatch needs to be done in the context.
* \return There are different types of the output
* 1) NullOpt if there is no feedback hint
* 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
* 4) IRModule for unified dispatch
* \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it
* under IRModule for more general future use. NullOpt is returned
* if there is no feedback hint.
*/
virtual Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) = 0;
virtual Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) = 0;

static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext";
TVM_DECLARE_BASE_OBJECT_INFO(MetaScheduleContextNode, runtime::Object);
Expand Down Expand Up @@ -123,15 +121,13 @@ class MetaScheduleContext : public runtime::ObjectRef {
* \param mod The high-level IR
* \param target Target info
* \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to
* \return There are different types of the output
* 1) NullOpt if there is no feedback hint
* 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
* 4) IRModule for unified dispatch
* \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it
* under IRModule for more general future use. NullOpt is returned
* if there is no feedback hint
*/
static Optional<ObjectRef> QueryInsideWithScope(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched);
static Optional<IRModule> QueryInsideWithScope(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched);

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetaScheduleContext, runtime::ObjectRef,
MetaScheduleContextNode);
Expand All @@ -145,38 +141,6 @@ class MetaScheduleContext : public runtime::ObjectRef {
void ExitWithScope();
};

/**************** TaskExtraction ****************/

/*!
* \brief An integration context for task extraction
*/
class TaskExtractionNode : public MetaScheduleContextNode {
public:
/*! \brief The extracted tasks */
Array<ExtractedTask> tasks{nullptr};

void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); }

// Inherited from base class
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) final;

static constexpr const char* _type_key = "meta_schedule.TaskExtraction";
TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, MetaScheduleContextNode);
};

/*!
* \brief Managed reference to TaskExtractionNode
* \sa TaskExtractionNode
*/
class TaskExtraction : public MetaScheduleContext {
public:
/*! \brief The path to a cache file storing extracted tasks */
TaskExtraction();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, MetaScheduleContext,
TaskExtractionNode);
};

/**************** ApplyHistoryBest ****************/

/*!
Expand All @@ -193,8 +157,8 @@ class ApplyHistoryBestNode : public MetaScheduleContextNode {
}

// Inherited from base class
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) final;
Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) final;

static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, MetaScheduleContextNode);
Expand Down
90 changes: 30 additions & 60 deletions python/tvm/meta_schedule/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
# specific language governing permissions and limitations
# under the License.
"""Meta schedule integration with high-level IR"""
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union

from tvm._ffi import register_object
import numpy as np # type: ignore
import tvm.runtime.ndarray as nd

from tvm._ffi import register_object, get_global_func
from tvm.ir import IRModule, transform
from tvm.relay import Any
from tvm.relay import Function as RelayFunc
from tvm.relay import vm
from tvm.runtime import NDArray, Object
from tvm.target import Target
from tvm.tir import PrimFunc

from . import _ffi_api
from .database import Database
Expand Down Expand Up @@ -77,7 +77,7 @@ def query(
mod: IRModule,
target: Target,
dispatched: Optional[List[IRModule]],
) -> Union[IRModule, RelayFunc, PrimFunc, None]:
) -> Union[IRModule, None]:
"""The entry point of the integration

Parameters
Expand All @@ -93,12 +93,9 @@ def query(

Returns
-------
result : Union[IRModule, RelayFunc, PrimFunc, None]
There are different types of the output:
1) NullOpt if there is no feedback hint;
2) tir::PrimFunc if `mod` should be lowered to a PrimFunc;
3) relay::Function if `mod` should be dispatched to BYOC workflow;
4) IRModule for unified dispatch
result : IRModule or None
Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for
more general future use. None is returned if there is no feedback hint.
"""
return _ffi_api.MetaScheduleContextQuery( # type: ignore # pylint: disable=no-member
self,
Expand Down Expand Up @@ -126,7 +123,7 @@ def query_inside_with_scope(
mod: IRModule,
target: Target,
dispatched: Optional[List[IRModule]],
) -> Union[IRModule, RelayFunc, PrimFunc, None]:
) -> Union[IRModule, None]:
"""The entry point of the integration workflow. The compilation process of the high-level
IR should call this method for task extraction and for feedback hints

Expand All @@ -137,7 +134,7 @@ def query_inside_with_scope(
def query_inside_with_scope(task_name, mod, dispatched):
ctx = MetaScheduleContext.current()
assert ctx is not None
ctx.query(task_name, mod, target, dispatched)
mod = ctx.query(task_name, mod, target, dispatched)

Parameters
----------
Expand All @@ -152,12 +149,9 @@ def query_inside_with_scope(task_name, mod, dispatched):

Returns
-------
result : Union[IRModule, RelayFunc, PrimFunc, None]
There are different types of the output:
1) NullOpt if there is no feedback hint;
2) tir::PrimFunc if `mod` should be lowered to a PrimFunc;
3) relay::Function if `mod` should be dispatched to BYOC workflow;
4) IRModule for unified dispatch
result : IRModule or None
Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for
more general future use. None is returned if there is no feedback hint.
"""
return _ffi_api.MetaScheduleContextQueryInsideWithScope( # type: ignore # pylint: disable=no-member
task_name,
Expand All @@ -176,17 +170,6 @@ def __exit__(self, ptype, value, trace) -> None:
_ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member


@register_object("meta_schedule.TaskExtraction")
class TaskExtraction(MetaScheduleContext):
"""An integration context for task extraction"""

tasks: List[ExtractedTask]
"""The extracted tasks"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(_ffi_api.TaskExtraction) # type: ignore # pylint: disable=no-member


@register_object("meta_schedule.ApplyHistoryBest")
class ApplyHistoryBest(MetaScheduleContext):
"""An integration context that allows application of historically best record from database"""
Expand Down Expand Up @@ -230,45 +213,32 @@ def extract_task_from_relay(
The tasks extracted from this network
"""

@contextmanager
def _autotvm_silencer():
from tvm import autotvm # pylint: disable=import-outside-toplevel

silent = autotvm.GLOBAL_SCOPE.silent
autotvm.GLOBAL_SCOPE.silent = True
try:
yield
finally:
autotvm.GLOBAL_SCOPE.silent = silent
extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask")
assert extract_task_func

def _thread_run(func: Callable[[], None]) -> None:
import threading # pylint: disable=import-outside-toplevel
target = Target(target) if isinstance(target, str) else target

thread = threading.Thread(target=func)
thread.start()
thread.join()
relay_params = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = nd.array(param)
relay_params[name] = param

if disabled_pass is None:
disabled_pass = []
if pass_config is None:
pass_config = {"relay.backend.use_meta_schedule": True}

env = TaskExtraction()
if isinstance(mod, RelayFunc):
mod = IRModule.from_expr(mod)
if not isinstance(target, Target):
target = Target(target)

def _func():
with env, _autotvm_silencer(), transform.PassContext(
masahi marked this conversation as resolved.
Show resolved Hide resolved
config=pass_config,
disabled_pass=disabled_pass,
opt_level=opt_level,
):
compiler = vm.VMCompiler()
if params:
compiler.set_params(params)
compiler.lower(mod, target)

_thread_run(_func)
return env.tasks
with target, transform.PassContext(
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
):
tasks = extract_task_func(mod, target, relay_params)
# Tasks are extracted via post order visit, return the reversed list.
return list(reversed(tasks))
1 change: 1 addition & 0 deletions python/tvm/topi/x86/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def batch_matmul_vnni_compute(cfg, x, y):
axis=ak,
),
tag="batch_matmul_vnni",
attrs={"schedule_rule": "meta_schedule.batch_matmul_vnni"},
)

_, a_y, _ = z.op.axis
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/x86/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None):
axis=ak,
),
tag="dense_vnni",
attrs={"schedule_rule": "meta_schedule.dense_vnni"},
)

if bias is not None:
Expand Down
Loading