Skip to content

Commit

Permalink
Layout Rewriting: Suggest-Index-Map (apache#520)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Nov 22, 2021
1 parent 87dbb2c commit 67af07c
Show file tree
Hide file tree
Showing 14 changed files with 633 additions and 57 deletions.
2 changes: 2 additions & 0 deletions include/tvm/meta_schedule/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

#include <tvm/meta_schedule/search_strategy.h>

#include <vector>

namespace tvm {
namespace meta_schedule {

Expand Down
46 changes: 46 additions & 0 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,52 @@ class LinkedParam : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

/*! \brief A mapping from multi-dimensional indices to another set of multi-dimensional indices */
class IndexMapNode : public Object {
public:
/*! \brief The variables standing for the source indices, one per source dimension */
Array<Var> src_iters;
/*! \brief The target indices, each an expression written in terms of `src_iters` */
Array<PrimExpr> tgt_iters;

// Reflection hook: exposes both arrays to TVM's attribute visitor
// (used for serialization, structural equality, and printing).
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("src_iters", &src_iters);
v->Visit("tgt_iters", &tgt_iters);
}

/*!
 * \brief Take `inputs` as the source indices and return the corresponding target indices.
 * \param inputs The source indices; must have the same length as `src_iters`.
 * \return The target indices, i.e. `tgt_iters` with each source variable replaced by
 *         the corresponding element of `inputs`.
 */
Array<PrimExpr> Apply(const Array<PrimExpr>& inputs) const;

static constexpr const char* _type_key = "tir.IndexMap";
TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object);
};

/*!
* \brief Managed reference to IndexMapNode.
* \sa IndexMapNode
*/
class IndexMap : public ObjectRef {
public:
/*!
 * \brief Constructor.
 * \param src_iters The source indices.
 * \param tgt_iters The target indices, written as expressions over `src_iters`.
 */
explicit IndexMap(Array<Var> src_iters, Array<PrimExpr> tgt_iters);
/*!
 * \brief Create an index map from a packed function.
 * \param ndim The number of source dimensions; one i32 source variable is created per dimension.
 * \param func The function that maps the source variables to the target index expressions.
 * \return The created index map.
 */
static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func);
TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
};

/*!
* \brief Tensor TensorIntrin for Tensorization
*/
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize

from .function import PrimFunc, TensorIntrin
from .function import PrimFunc, IndexMap, TensorIntrin

from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
Expand Down
102 changes: 86 additions & 16 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,19 @@
# under the License.
"""Function data types."""

from typing import Mapping, Union
import inspect
from typing import Callable, List, Mapping, Union

import tvm._ffi
import tvm.runtime
from tvm.runtime import Object
from tvm._ffi import get_global_func, register_object
from tvm.ir import BaseFunc
from .buffer import Buffer
from .expr import Var, PrimExpr
from tvm.runtime import Object, convert

from . import _ffi_api
from .buffer import Buffer
from .expr import PrimExpr, Var


@tvm._ffi.register_object("tir.PrimFunc")
@register_object("tir.PrimFunc")
class PrimFunc(BaseFunc):
"""A function declaration expression.
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa
param_list = []
buffer_map = {} if buffer_map is None else buffer_map
for x in params:
x = tvm.runtime.convert(x) if not isinstance(x, Object) else x
x = convert(x) if not isinstance(x, Object) else x
if isinstance(x, Buffer):
var = Var(x.name, dtype="handle")
param_list.append(var)
Expand All @@ -67,7 +68,13 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa
raise TypeError("params can only contain Var or Buffer")

self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span # type: ignore
_ffi_api.PrimFunc, # type: ignore # pylint: disable=no-member
param_list,
body,
ret_type,
buffer_map,
attrs,
span,
)

def with_body(self, new_body, span=None):
Expand Down Expand Up @@ -141,7 +148,7 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
func : PrimFunc
The new function with parameter specialized
"""
return _ffi_api.Specialize(self, param_map) # type: ignore
return _ffi_api.Specialize(self, param_map) # type: ignore # pylint: disable=no-member

def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str:
"""Print IRModule into TVMScript
Expand All @@ -159,11 +166,72 @@ def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str:
script : str
The TVM Script of the PrimFunc
"""
return tvm._ffi.get_global_func("script.AsTVMScript")(
self, tir_prefix, show_meta
) # type: ignore
return get_global_func("script.AsTVMScript")(self, tir_prefix, show_meta) # type: ignore


@register_object("tir.IndexMap")
class IndexMap(Object):
    """A mapping from multi-dimensional indices to another set of
    multi-dimensional indices.

    Parameters
    ----------
    src_iters : List[Var]
        The source index variables, one per source dimension.
    tgt_iters : List[PrimExpr]
        The target index expressions, written in terms of the source variables.
    """

    # The source index variables.
    src_iters: List[Var]
    # The target index expressions.
    tgt_iters: List[PrimExpr]

    def __init__(self, src_iters: List[Var], tgt_iters: List[PrimExpr]):
        # Fixed: the constructor hook on tvm.runtime.Object is
        # `__init_handle_by_constructor__`; the previous `_init_handle_by_constructor`
        # does not exist and raised AttributeError on construction.
        self.__init_handle_by_constructor__(
            _ffi_api.IndexMap,  # type: ignore # pylint: disable=no-member
            src_iters,
            tgt_iters,
        )

    def apply(self, indices: List[PrimExpr]) -> List[PrimExpr]:
        """Apply the index map to a set of indices.

        Parameters
        ----------
        indices : List[PrimExpr]
            The indices to be mapped.

        Returns
        -------
        result : List[PrimExpr]
            The mapped indices.
        """
        return _ffi_api.IndexMapApply(self, indices)  # type: ignore # pylint: disable=no-member

    @staticmethod
    def from_func(func: Callable) -> "IndexMap":
        """Create an index map from a function.

        Parameters
        ----------
        func : Callable
            A function taking one positional argument per source dimension and
            returning the target indices as a PrimExpr, or a list/tuple of them.

        Returns
        -------
        index_map : IndexMap
            The created index map.
        """

        def wrap(args: List[Var]) -> List[PrimExpr]:
            # Normalize the user function's return value to a list of PrimExpr.
            result = func(*args)
            if isinstance(result, tuple):
                return list(result)
            if not isinstance(result, list):
                result = [result]
            return result

        # The number of source dimensions is the user function's arity.
        ndim = len(inspect.signature(func).parameters)
        return _ffi_api.IndexMapFromFunc(ndim, wrap)  # type: ignore # pylint: disable=no-member


@register_object("tir.TensorIntrin")
class TensorIntrin(Object):
"""A function declaration expression.
Expand All @@ -177,7 +245,9 @@ class TensorIntrin(Object):
"""

def __init__(self, desc_func, intrin_func):
self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc_func, intrin_func)
self.__init_handle_by_constructor__(
_ffi_api.TensorIntrin, desc_func, intrin_func # type: ignore # pylint: disable=no-member
)

@staticmethod
def register(name: str, desc_func: PrimFunc, intrin_func: PrimFunc):
Expand All @@ -187,4 +257,4 @@ def register(name: str, desc_func: PrimFunc, intrin_func: PrimFunc):

@staticmethod
def get(name: str):
return _ffi_api.TensorIntrinGet(name) # pylint: disable=no-member
return _ffi_api.TensorIntrinGet(name) # pylint: disable=no-member
2 changes: 2 additions & 0 deletions python/tvm/tir/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError
from .state import ScheduleDebugMask, ScheduleState
from .trace import Trace

from . import analysis
58 changes: 58 additions & 0 deletions python/tvm/tir/schedule/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Analysis used in TensorIR scheduling"""
from typing import List, Optional

from ..buffer import Buffer
from ..stmt import For
from ..expr import PrimExpr
from ..function import IndexMap

from . import _ffi_api


def suggest_index_map(
    buffer: Buffer,
    indices: List[PrimExpr],
    loops: List[For],
    predicate: PrimExpr,
) -> Optional[IndexMap]:
    """Given the access pattern to a buffer, suggest a possible layout
    transformation that improves the locality of the accesses.

    Parameters
    ----------
    buffer : Buffer
        The buffer to be transformed.
    indices : List[PrimExpr]
        The access pattern to the buffer.
    loops : List[For]
        The loops above the buffer.
    predicate : PrimExpr
        The predicate of the access.

    Returns
    -------
    index_map : Optional[IndexMap]
        The suggested index map, or None if no transformation is suggested.
    """
    # Delegate to the C++ analysis through the FFI boundary.
    return _ffi_api.SuggestIndexMap(buffer, indices, loops, predicate)  # type: ignore # pylint: disable=no-member
1 change: 0 additions & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
"""The TensorIR schedule class"""
from typing import Dict, List, Optional, Union
from typing_extensions import Annotated

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
Expand Down
69 changes: 67 additions & 2 deletions src/tir/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -64,14 +65,78 @@ FuncType PrimFuncNode::func_type_annotation() const {

TVM_REGISTER_NODE_TYPE(PrimFuncNode);

Array<PrimExpr> IndexMapNode::Apply(const Array<PrimExpr>& inputs) const {
  // One input expression is required for each source iteration variable.
  CHECK_EQ(inputs.size(), this->src_iters.size());
  // Bind each source variable to the corresponding input expression.
  std::unordered_map<const VarNode*, PrimExpr> bindings;
  bindings.reserve(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    bindings.emplace(this->src_iters[i].get(), inputs[i]);
  }
  // Rewrite every target expression under the bindings.
  Array<PrimExpr> outputs;
  outputs.reserve(this->tgt_iters.size());
  for (const PrimExpr& tgt : this->tgt_iters) {
    outputs.push_back(Substitute(tgt, bindings));
  }
  return outputs;
}

IndexMap::IndexMap(Array<Var> src_iters, Array<PrimExpr> tgt_iters) {
  // Populate the node, taking ownership of both arrays without copying.
  auto node = make_object<IndexMapNode>();
  node->src_iters = std::move(src_iters);
  node->tgt_iters = std::move(tgt_iters);
  data_ = std::move(node);
}

IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func) {
  // Fabricate one i32 source variable per dimension ("i0", "i1", ...), then let
  // `func` compute the target expressions over them.
  Array<Var> source_vars;
  source_vars.reserve(ndim);
  for (int dim = 0; dim < ndim; ++dim) {
    source_vars.push_back(Var("i" + std::to_string(dim), DataType::Int(32)));
  }
  return IndexMap(source_vars, func(source_vars));
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IndexMapNode>([](const ObjectRef& node, ReprPrinter* p) {
      const auto* n = node.as<IndexMapNode>();
      ICHECK(n);
      // Shared helper: print a parenthesized, comma-separated list of items.
      auto print_list = [p](const auto& items) {
        p->stream << "(";
        bool first = true;
        for (const auto& item : items) {
          if (!first) {
            p->stream << ", ";
          }
          first = false;
          p->stream << item;
        }
        p->stream << ")";
      };
      p->stream << "IndexMap: ";
      print_list(n->src_iters);
      p->stream << " => ";
      print_list(n->tgt_iters);
    });

// Register the node type for TVM's reflection/serialization machinery.
TVM_REGISTER_NODE_TYPE(IndexMapNode);
// FFI: construct an IndexMap from explicit source variables and target expressions.
TVM_REGISTER_GLOBAL("tir.IndexMap")
    .set_body_typed([](Array<Var> src_iters, Array<PrimExpr> tgt_iters) {
      return IndexMap(src_iters, tgt_iters);
    });
// FFI: construct an IndexMap from a packed function (see IndexMap::FromFunc).
TVM_REGISTER_GLOBAL("tir.IndexMapFromFunc").set_body_typed(IndexMap::FromFunc);
// FFI: apply an IndexMap to a set of indices (see IndexMapNode::Apply).
TVM_REGISTER_GLOBAL("tir.IndexMapApply").set_body_method<IndexMap>(&IndexMapNode::Apply);

TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) {
// check the number of func var is equal
CHECK_EQ(desc_func->params.size(), intrin_func->params.size());
CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size());

// check both functions' bodies are directly block
const auto* desc_realize = Downcast<BlockRealize>(desc_func->body)->block->body.as<BlockRealizeNode>();
const auto* intrin_realize = Downcast<BlockRealize>(intrin_func->body)->block->body.as<BlockRealizeNode>();
const auto* desc_realize =
Downcast<BlockRealize>(desc_func->body)->block->body.as<BlockRealizeNode>();
const auto* intrin_realize =
Downcast<BlockRealize>(intrin_func->body)->block->body.as<BlockRealizeNode>();
CHECK(desc_realize != nullptr) << "description function's body expect a directly block";
CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a directly block";

Expand Down
Loading

0 comments on commit 67af07c

Please sign in to comment.