Skip to content

Commit

Permalink
[MetaSchedule] Multi-Level-Tiling & Auto-Inline (apache#503)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Nov 9, 2021
1 parent bfb8561 commit 2ba9aad
Show file tree
Hide file tree
Showing 27 changed files with 2,323 additions and 54 deletions.
47 changes: 44 additions & 3 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` funcion. */
/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` funcion. */
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` funcion. */
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand Down Expand Up @@ -110,6 +110,47 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
*/
class ScheduleRule : public runtime::ObjectRef {
public:
/*!
* \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
* \param into_producer If allows to inline a block into its producer
* \param into_consumer If allows to inline a block into its consumer
* \param into_cache_only If it only allows to inline into a block generated by cache_read/write
* \param inline_const_tensor Always inline constant tensors
* \param disallow_if_then_else Always disallow if-then-else-like constructs
* \param require_injective Always require the read-to-write mapping to be injective
* \param require_ordered Always require the read-to-write mapping to be ordered
* \param disallow_op The operators that are disallowed in auto inline
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
bool into_consumer, //
bool into_cache_only, //
bool inline_const_tensor, //
bool disallow_if_then_else, //
bool require_injective, //
bool require_ordered, //
Optional<Array<String>> disallow_op);
/*!
* \brief Create a mega rule: multi-level tiling with data reuse
* \param structure The tiling structure. Recommended:
* - 'SSRSRS' on CPU
* - 'SSSRRSRS' on GPU
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
* - NullOpt on CPU
* - [blockIdx.x, vthread.x, threadIdx.x] on GPU
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \param vector_load_max_len The length of vector lane in vectorized cooperative fetching.
* NullOpt means disable vectorization
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
Optional<Array<String>> tile_binds, //
Optional<Integer> max_innermost_factor, //
Optional<Integer> vector_load_max_len, //
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write);
/*!
* \brief Create a schedule rule with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
Expand Down
22 changes: 21 additions & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1224,7 +1224,7 @@ class BlockRealize : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode);
};

/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
/*! \brief namespace of possible attributes in AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
/*! \brief Mark launching extent of thread, used by device API. */
Expand Down Expand Up @@ -1355,6 +1355,26 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
*/
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

/*!
* \brief Mark that the loop should be further split and bound to environment threads to enable
* cooperative fetching.
*/
constexpr const char* meta_schedule_lazy_cooperative_fetch = "meta_schedule.lazy_cooperative_fetch";

/*!
* \brief Mark whether a block is generated by cache_read or cache_write.
* 0 means cache_read; 1 means cache_write.
* \sa meta_schedule_cache_type_read
* \sa meta_schedule_cache_type_write
*/
constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";

/*! \sa meta_schedule_cache_type */
constexpr const int meta_schedule_cache_type_read = 0;

/*! \sa meta_schedule_cache_type */
constexpr const int meta_schedule_cache_type_write = 1;

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@
Meta Schedule schedule rules are used for modification of
blocks in a schedule. See also PostOrderApply.
"""
from .schedule_rule import ScheduleRule, PyScheduleRule
from .auto_inline import AutoInline
from .multi_level_tiling import MultiLevelTiling, ReuseType
from .schedule_rule import PyScheduleRule, ScheduleRule
71 changes: 71 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/auto_inline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions"""
from typing import List, Optional

from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


@register_object("meta_schedule.AutoInline")
class AutoInline(ScheduleRule):
    """Rule that inlines spatial blocks if it satisfies some conditions

    Parameters
    ----------
    into_producer : bool
        If allows to inline a block into its producer
    into_consumer : bool
        If allows to inline a block into its consumer
    into_cache_only : bool
        If it only allows to inline into a block generated by cache_read/write
    inline_const_tensor : bool
        Always inline constant tensors
    disallow_if_then_else : bool
        Always disallow if-then-else-like constructs
    require_injective : bool
        Always require the read-to-write mapping to be injective
    require_ordered : bool
        Always require the read-to-write mapping to be ordered
    disallow_op : Optional[List[str]]
        The operators that are disallowed in auto inline
    """

    def __init__(
        self,
        into_producer: bool,
        into_consumer: bool,
        into_cache_only: bool,
        inline_const_tensor: bool,
        disallow_if_then_else: bool,
        require_injective: bool,
        require_ordered: bool,
        disallow_op: Optional[List[str]] = None,
    ) -> None:
        # Hand the flags to the C++ AutoInline constructor; argument order must
        # match the FFI signature registered as `ScheduleRuleAutoInline`.
        self.__init_handle_by_constructor__(
            _ffi_api.ScheduleRuleAutoInline,  # type: ignore # pylint: disable=no-member
            into_producer,
            into_consumer,
            into_cache_only,
            inline_const_tensor,
            disallow_if_then_else,
            require_injective,
            require_ordered,
            disallow_op,
        )
84 changes: 84 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Multi-level tiling with reuse."""
from typing import Any, Dict, List, Literal, NamedTuple, Optional

from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


class ReuseType(NamedTuple):
    """Data-reuse configuration: requirement level, tiling levels, and storage scope."""

    # Whether reuse is forbidden, optional, or mandatory.
    req: Literal["no", "may", "must"]
    # The tile levels the reuse applies to.
    levels: List[int]
    # The storage scope of the reuse buffer.
    scope: str

    def as_dict(self) -> Dict[str, Any]:
        """Return the dict representation of the reuse type."""
        # NamedTuple._asdict already yields the fields keyed by name, in order.
        return dict(self._asdict())


@register_object("meta_schedule.MultiLevelTiling")
class MultiLevelTiling(ScheduleRule):
    """Multi-level tiling with reuse.

    Parameters
    ----------
    structure : str
        The tiling structure. Recommended:
        - 'SSRSRS' on CPU
        - 'SSSRRSRS' on GPU
    tile_binds : Optional[List[str]]
        For each level of tiles, which thread axis it is bound to. Recommended:
        - None on CPU
        - [blockIdx.x, vthread.x, threadIdx.x] on GPU
    max_innermost_factor : Optional[int]
        The maximum size of the innermost factor. None means no limit
    vector_load_max_len : Optional[int]
        The length of vector lane in vectorized cooperative fetching.
        None means disable vectorization
    reuse_read : Optional[ReuseType]
        Data reuse configuration for reading. None means no reuse.
    reuse_write : Optional[ReuseType]
        Data reuse configuration for writing. None means no reuse.
    """

    def __init__(
        self,
        structure: str,
        tile_binds: Optional[List[str]] = None,
        max_innermost_factor: Optional[int] = None,
        vector_load_max_len: Optional[int] = None,
        reuse_read: Optional[ReuseType] = None,
        reuse_write: Optional[ReuseType] = None,
    ) -> None:
        # The FFI boundary takes plain dicts, so ReuseType tuples are
        # converted via as_dict() before crossing into C++.
        self.__init_handle_by_constructor__(
            _ffi_api.ScheduleRuleMultiLevelTiling,  # type: ignore # pylint: disable=no-member
            structure,
            tile_binds,
            max_innermost_factor,
            vector_load_max_len,
            reuse_read.as_dict() if reuse_read is not None else None,
            reuse_write.as_dict() if reuse_write is not None else None,
        )
8 changes: 5 additions & 3 deletions python/tvm/meta_schedule/schedule_rule/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
self, tune_context
)

def apply(self, schedule: Schedule, block: BlockRV) -> List[Schedule]:
def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
"""Apply a schedule rule to the specific block in the given schedule.
Parameters
Expand All @@ -62,7 +62,9 @@ def apply(self, schedule: Schedule, block: BlockRV) -> List[Schedule]:
design_spaces : List[Schedule]
The list of schedules generated by applying the schedule rule.
"""
return _ffi_api.ScheduleRuleApply(self, schedule, block)
return _ffi_api.ScheduleRuleApply( # type: ignore # pylint: disable=no-member
self, sch, block
)


@register_object("meta_schedule.PyScheduleRule")
Expand Down Expand Up @@ -91,4 +93,4 @@ def f_as_string() -> str:
)

def __str__(self) -> str:
return f"PyScheduleRule({_get_hex_address(self.handle)})"
return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities in meta schedule"""
from . import te_workload
from . import schedule_rule
from .local_rpc import LocalRPC
from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model
from .te_workload import create_te_workload
Loading

0 comments on commit 2ba9aad

Please sign in to comment.