[Metaschedule] New relay backend for meta schedule task extraction (#10578)

* New relay backend for meta schedule task extraction

commit 501fac6
Merge: 076fa33 ce8c563
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 14:16:47 2022 +0900

    New relay backend for meta schedule task extraction

commit ce8c563
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 14:12:30 2022 +0900

    fix cpplint

commit dfa4fb0
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 14:09:11 2022 +0900

    update expected op list in
    test_meta_schedule_integration_extract_from_resnet to remove dep on Ansor

commit a98182e
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 13:56:35 2022 +0900

    fixed test_meta_schedule_integration_apply_history_best

commit 40d52a1
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 13:50:43 2022 +0900

    uniquefy task names

commit dfaf496
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 13:45:30 2022 +0900

    dedup tasks

commit e49d500
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 12:59:45 2022 +0900

    return reversed list

commit 74636be
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 12:39:58 2022 +0900

    refactor

commit 99f1701
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 12:34:14 2022 +0900

    clean up integration.cc and Query interface

commit 3f93a1e
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 11:54:57 2022 +0900

    check in minor vnni-related change

commit af3e988
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 07:36:35 2022 +0900

    Removed TaskExtraction node

commit 7b4d35e
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 05:42:56 2022 +0900

    add doc to util functions

commit 3c5a318
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 05:27:53 2022 +0900

    rename to task extraction

commit 57f2882
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 05:24:37 2022 +0900

    fixed constant param bind

commit f099537
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 05:10:44 2022 +0900

    remove unused stuff from python extract_tasks_from_relay

commit 4a5e4aa
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 05:10:30 2022 +0900

    move BindParams function to cc file

commit efeccea
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 03:56:05 2022 +0900

    refactor param binding

commit 109187f
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 02:21:58 2022 +0900

    New relay backend for meta schedule task extraction

commit 6f01901
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 11:25:44 2022 +0900

    fixed anchor impl selection

commit be6c258
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 10:57:02 2022 +0900

    Forgot visiting arg in ScheduleBuilder CallNode visit

commit 0c6d4a6
Author: Masahiro Masuda <[email protected]>
Date:   Fri Mar 11 10:45:08 2022 +0900

    add public, fix include path convention

commit 4cd3a16
Author: Masahiro Masuda <[email protected]>
Date:   Thu Mar 10 18:43:15 2022 +0900

    removed create_schedule stuff

commit eb1bc7e
Author: Masahiro Masuda <[email protected]>
Date:   Thu Mar 10 18:13:42 2022 +0900

    fixed merge conflict

commit 6e68fd9
Author: Masahiro Masuda <[email protected]>
Date:   Thu Mar 10 14:27:34 2022 +0900

    Decouple TE compute and schedule lowering in ScheduleBuilder

* update integration.h doc

* remove unused import

* fix mypy check

* use_meta_schedule restored, now extracts the same task as Ansor

* python doc update

* unused import

* cache_ -> cache, suppress "Cannot find workload" warning

* Update src/relay/backend/task_extraction.cc and te_compiler_cache.cc

Co-authored-by: Junru Shao <[email protected]>

* removed unnecessary include

* fixed build

* drop relay.const on params

* updated comment in integration.cc

* update schedule_rule name to prepend "metaschedule"

* typo fix

* more nit change

* make the output of Query Optional

* update py doc

* remove TODO comment on parse_mod

Co-authored-by: Junru Shao <[email protected]>
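Several commits in the log above ("dedup tasks", "uniquefy task names", "return reversed list") describe how the extracted task list is post-processed before it is returned. A self-contained sketch of that post-processing, using plain-Python tuples as stand-ins for TVM's actual `ExtractedTask` objects (the names and suffix scheme here are illustrative, not TVM's exact implementation):

```python
def postprocess_tasks(tasks_post_order):
    """Dedup repeated tasks, uniquify duplicate names, reverse the post-order list."""
    seen_mods = set()
    name_counts = {}
    result = []
    for name, mod_key in tasks_post_order:
        if mod_key in seen_mods:  # dedup: skip structurally identical tasks
            continue
        seen_mods.add(mod_key)
        n = name_counts.get(name, 0)
        name_counts[name] = n + 1
        # uniquify: append an index when the same task name appears again
        unique_name = name if n == 0 else f"{name}_{n}"
        result.append((unique_name, mod_key))
    # tasks were collected via post-order visit; reverse for outer-first order
    return list(reversed(result))

tasks = [("dense", "mod_a"), ("dense", "mod_a"), ("dense", "mod_b"), ("relu", "mod_c")]
print(postprocess_tasks(tasks))
# [('relu', 'mod_c'), ('dense_1', 'mod_b'), ('dense', 'mod_a')]
```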
masahi and junrushao authored Mar 16, 2022
1 parent ab4289d commit ce335c3
Showing 13 changed files with 264 additions and 245 deletions.
62 changes: 13 additions & 49 deletions include/tvm/meta_schedule/integration.h
@@ -86,14 +86,12 @@ class MetaScheduleContextNode : public runtime::Object {
* \param target Target info
* \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to.
* NullOpt means the dispatch needs to be done in the context.
* \return There are different types of the output
* 1) NullOpt if there is no feedback hint
* 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
* 4) IRModule for unified dispatch
* \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it
* under IRModule for more general future use. NullOpt is returned
* if there is no feedback hint.
*/
virtual Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) = 0;
virtual Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) = 0;

static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext";
TVM_DECLARE_BASE_OBJECT_INFO(MetaScheduleContextNode, runtime::Object);
@@ -123,15 +121,13 @@ class MetaScheduleContext : public runtime::ObjectRef {
* \param mod The high-level IR
* \param target Target info
* \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to
* \return There are different types of the output
* 1) NullOpt if there is no feedback hint
* 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
* 4) IRModule for unified dispatch
* \return IRModule or NullOpt Currently we only have to return tir::PrimFunc, but we wrap it
* under IRModule for more general future use. NullOpt is returned
* if there is no feedback hint
*/
static Optional<ObjectRef> QueryInsideWithScope(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched);
static Optional<IRModule> QueryInsideWithScope(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched);

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetaScheduleContext, runtime::ObjectRef,
MetaScheduleContextNode);
@@ -145,38 +141,6 @@ class MetaScheduleContext : public runtime::ObjectRef {
void ExitWithScope();
};

/**************** TaskExtraction ****************/

/*!
* \brief An integration context for task extraction
*/
class TaskExtractionNode : public MetaScheduleContextNode {
public:
/*! \brief The extracted tasks */
Array<ExtractedTask> tasks{nullptr};

void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); }

// Inherited from base class
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) final;

static constexpr const char* _type_key = "meta_schedule.TaskExtraction";
TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, MetaScheduleContextNode);
};

/*!
* \brief Managed reference to TaskExtractionNode
* \sa TaskExtractionNode
*/
class TaskExtraction : public MetaScheduleContext {
public:
/*! \brief The path to a cache file storing extracted tasks */
TaskExtraction();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, MetaScheduleContext,
TaskExtractionNode);
};

/**************** ApplyHistoryBest ****************/

/*!
@@ -193,8 +157,8 @@ class ApplyHistoryBestNode : public MetaScheduleContextNode {
}

// Inherited from base class
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) final;
Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) final;

static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, MetaScheduleContextNode);
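The integration.h diff above narrows `Query`'s return type from `Optional<ObjectRef>` to `Optional<IRModule>`: a context now either hands back a module for the queried task or `NullOpt` when it has no feedback hint. A minimal pure-Python sketch of that contract, using a dict and strings as stand-ins for TVM's database and `IRModule` (not TVM's real classes):

```python
from typing import Dict, Optional

class ApplyHistoryBestSketch:
    """Stand-in for ApplyHistoryBest: query() returns the tuned module
    for a task if the database has a record, else None (no feedback hint)."""

    def __init__(self, database: Dict[str, str]):
        # task_name -> tuned "IRModule" stand-in
        self.database = database

    def query(self, task_name: str, mod: str) -> Optional[str]:
        # Single return type: a module or None, never a bare PrimFunc/Function.
        return self.database.get(task_name)

ctx = ApplyHistoryBestSketch({"fused_conv2d": "tuned_conv2d_mod"})
print(ctx.query("fused_conv2d", "conv2d_mod"))  # tuned_conv2d_mod
print(ctx.query("fused_dense", "dense_mod"))    # None
```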
90 changes: 30 additions & 60 deletions python/tvm/meta_schedule/integration.py
@@ -15,17 +15,17 @@
# specific language governing permissions and limitations
# under the License.
"""Meta schedule integration with high-level IR"""
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union

from tvm._ffi import register_object
import numpy as np # type: ignore
import tvm.runtime.ndarray as nd

from tvm._ffi import register_object, get_global_func
from tvm.ir import IRModule, transform
from tvm.relay import Any
from tvm.relay import Function as RelayFunc
from tvm.relay import vm
from tvm.runtime import NDArray, Object
from tvm.target import Target
from tvm.tir import PrimFunc

from . import _ffi_api
from .database import Database
@@ -77,7 +77,7 @@ def query(
mod: IRModule,
target: Target,
dispatched: Optional[List[IRModule]],
) -> Union[IRModule, RelayFunc, PrimFunc, None]:
) -> Union[IRModule, None]:
"""The entry point of the integration
Parameters
@@ -93,12 +93,9 @@ def query(
Returns
-------
result : Union[IRModule, RelayFunc, PrimFunc, None]
There are different types of the output:
1) NullOpt if there is no feedback hint;
2) tir::PrimFunc if `mod` should be lowered to a PrimFunc;
3) relay::Function if `mod` should be dispatched to BYOC workflow;
4) IRModule for unified dispatch
result : IRModule or None
Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for
more general future use. None is returned if there is no feedback hint.
"""
return _ffi_api.MetaScheduleContextQuery( # type: ignore # pylint: disable=no-member
self,
@@ -126,7 +123,7 @@ def query_inside_with_scope(
mod: IRModule,
target: Target,
dispatched: Optional[List[IRModule]],
) -> Union[IRModule, RelayFunc, PrimFunc, None]:
) -> Union[IRModule, None]:
"""The entry point of the integration workflow. The compilation process of the high-level
IR should call this method for task extraction and for feedback hints
@@ -137,7 +134,7 @@
def query_inside_with_scope(task_name, mod, dispatched):
ctx = MetaScheduleContext.current()
assert ctx is not None
ctx.query(task_name, mod, target, dispatched)
mod = ctx.query(task_name, mod, target, dispatched)
Parameters
----------
@@ -152,12 +149,9 @@ def query_inside_with_scope(task_name, mod, dispatched):
Returns
-------
result : Union[IRModule, RelayFunc, PrimFunc, None]
There are different types of the output:
1) NullOpt if there is no feedback hint;
2) tir::PrimFunc if `mod` should be lowered to a PrimFunc;
3) relay::Function if `mod` should be dispatched to BYOC workflow;
4) IRModule for unified dispatch
result : IRModule or None
Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for
more general future use. None is returned if there is no feedback hint.
"""
return _ffi_api.MetaScheduleContextQueryInsideWithScope( # type: ignore # pylint: disable=no-member
task_name,
@@ -176,17 +170,6 @@ def __exit__(self, ptype, value, trace) -> None:
_ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member


@register_object("meta_schedule.TaskExtraction")
class TaskExtraction(MetaScheduleContext):
"""An integration context for task extraction"""

tasks: List[ExtractedTask]
"""The extracted tasks"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(_ffi_api.TaskExtraction) # type: ignore # pylint: disable=no-member


@register_object("meta_schedule.ApplyHistoryBest")
class ApplyHistoryBest(MetaScheduleContext):
"""An integration context that allows application of historically best record from database"""
@@ -230,45 +213,32 @@ def extract_task_from_relay(
The tasks extracted from this network
"""

@contextmanager
def _autotvm_silencer():
from tvm import autotvm # pylint: disable=import-outside-toplevel

silent = autotvm.GLOBAL_SCOPE.silent
autotvm.GLOBAL_SCOPE.silent = True
try:
yield
finally:
autotvm.GLOBAL_SCOPE.silent = silent
extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask")
assert extract_task_func

def _thread_run(func: Callable[[], None]) -> None:
import threading # pylint: disable=import-outside-toplevel
target = Target(target) if isinstance(target, str) else target

thread = threading.Thread(target=func)
thread.start()
thread.join()
relay_params = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = nd.array(param)
relay_params[name] = param

if disabled_pass is None:
disabled_pass = []
if pass_config is None:
pass_config = {"relay.backend.use_meta_schedule": True}

env = TaskExtraction()
if isinstance(mod, RelayFunc):
mod = IRModule.from_expr(mod)
if not isinstance(target, Target):
target = Target(target)

def _func():
with env, _autotvm_silencer(), transform.PassContext(
config=pass_config,
disabled_pass=disabled_pass,
opt_level=opt_level,
):
compiler = vm.VMCompiler()
if params:
compiler.set_params(params)
compiler.lower(mod, target)

_thread_run(_func)
return env.tasks
with target, transform.PassContext(
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
):
tasks = extract_task_func(mod, target, relay_params)
# Tasks are extracted via post order visit, return the reversed list.
return list(reversed(tasks))
1 change: 1 addition & 0 deletions python/tvm/topi/x86/batch_matmul.py
@@ -47,6 +47,7 @@ def batch_matmul_vnni_compute(cfg, x, y):
axis=ak,
),
tag="batch_matmul_vnni",
attrs={"schedule_rule": "meta_schedule.batch_matmul_vnni"},
)

_, a_y, _ = z.op.axis
1 change: 1 addition & 0 deletions python/tvm/topi/x86/dense.py
@@ -296,6 +296,7 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None):
axis=ak,
),
tag="dense_vnni",
attrs={"schedule_rule": "meta_schedule.dense_vnni"},
)

if bias is not None:
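The two TOPI changes above attach a `"schedule_rule"` attribute (e.g. `"meta_schedule.dense_vnni"`) to the VNNI compute definitions so meta schedule can pick a matching rule by name. A hedged sketch of such name-keyed dispatch, using a plain-Python registry rather than TVM's actual mechanism (all names here are illustrative):

```python
# Registry mapping "schedule_rule" names to schedule functions, mimicking
# how an op tagged with attrs={"schedule_rule": ...} could be dispatched.
SCHEDULE_RULES = {}

def register_rule(name):
    def wrap(fn):
        SCHEDULE_RULES[name] = fn
        return fn
    return wrap

@register_rule("meta_schedule.dense_vnni")
def schedule_dense_vnni(op_name):
    return f"vnni schedule for {op_name}"

def apply_schedule(op_name, attrs):
    rule = attrs.get("schedule_rule")
    if rule in SCHEDULE_RULES:  # dispatch by the attached tag
        return SCHEDULE_RULES[rule](op_name)
    return f"fallback schedule for {op_name}"

print(apply_schedule("dense", {"schedule_rule": "meta_schedule.dense_vnni"}))
# vnni schedule for dense
print(apply_schedule("relu", {}))
# fallback schedule for relu
```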
