Skip to content

Commit

Permalink
[TIR][Analysis] Add SuggestIndexMap for layout rewriting (apache#10732)
Browse files Browse the repository at this point in the history
This PR added an analysis function `SuggestIndexMap` to analyze buffer access pattern and suggest index map for layout transformations.

Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
  • Loading branch information
7 people authored and pfk-beta committed Apr 11, 2022
1 parent 8e9778f commit 60a6db2
Show file tree
Hide file tree
Showing 8 changed files with 411 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize

from .function import PrimFunc, TensorIntrin
from .function import PrimFunc, TensorIntrin, IndexMap

from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,18 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None):

final_indices = mapping_function(*args)
return IndexMap(args, final_indices)

def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]:
"""Apply the index map to a set of indices
Parameters
----------
indices : List[PriExpr]
The indices to be mapped
Returns
-------
result : List[PrimExpr]
The mapped indices
"""
return _ffi_api.IndexMapMapIndices(self, indices)
2 changes: 2 additions & 0 deletions python/tvm/tir/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError
from .state import ScheduleDebugMask, ScheduleState
from .trace import Trace

from . import analysis
58 changes: 58 additions & 0 deletions python/tvm/tir/schedule/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Analysis used in TensorIR scheduling"""
from typing import List, Optional

from ..buffer import Buffer
from ..stmt import For
from ..expr import PrimExpr
from ..function import IndexMap

from . import _ffi_api


def suggest_index_map(
buffer: Buffer,
indices: List[PrimExpr],
loops: List[For],
predicate: PrimExpr,
) -> Optional[IndexMap]:
"""Provided the access pattern to a buffer, suggest one of the possible layout
transformation to maximize the locality of the access pattern.
Parameters
----------
buffer : Buffer
The buffer to be transformed.
indices : List[PrimExpr]
The access pattern to the buffer.
loops : List[For]
The loops above the buffer.
predicate : PrimExpr
The predicate of the access.
Returns
-------
index_map : Optional[IndexMap]
The suggested index map. None if no transformation is suggested.
"""
return _ffi_api.SuggestIndexMap( # type: ignore # pylint: disable=no-member
buffer,
indices,
loops,
predicate,
)
2 changes: 2 additions & 0 deletions src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,5 +201,7 @@ TVM_REGISTER_GLOBAL("tir.IndexMap")
return IndexMap(initial_indices, final_indices);
});

TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices").set_body_method<IndexMap>(&IndexMapNode::MapIndices);

} // namespace tir
} // namespace tvm
14 changes: 14 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/schedule/state.h>

#include <tuple>
Expand Down Expand Up @@ -520,6 +521,19 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S
bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops);

/*!
* \brief Provided the access pattern to a buffer, suggest one of the possible layout
* transformation to minimize the locality of the access pattern.
* \param buffer The buffer to be transformed
* \param indices The access pattern to the buffer
* \param loops The loops above the buffer
* \param predicate The predicate of the access
* \param analyzer Arithmetic analyzer
*/
Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>& indices,
const Array<For>& loops, const PrimExpr& predicate,
arith::Analyzer* analyzer);

/*!
* \brief Checks if the given AST contains the specific operators
* \param stmt The AST statement to be checked
Expand Down
212 changes: 212 additions & 0 deletions src/tir/schedule/analysis/layout.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

/*!
* \brief Calculate the strides of the buffer
* \param buffer The buffer
* \return The strides
*/
Array<PrimExpr> GetStrides(const Buffer& buffer) {
if (!buffer->strides.empty()) {
ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
return buffer->strides;
}
int ndim = buffer->shape.size();
if (ndim == 0) {
return {};
}
Array<PrimExpr> strides(ndim, PrimExpr{nullptr});
PrimExpr stride = make_const(buffer->DefaultIndexType(), 1);
for (int i = ndim - 1; i >= 0; --i) {
strides.Set(i, stride);
stride = stride * buffer->shape[i];
}
return strides;
}

/*!
* \brief Auxiliary class that collects the IterSplitExpr in the indexing pattern
* to help decision making in layout transformation
*/
class SplitExprCollector {
public:
/*!
* \brief The corresponding IterSplitExpr, simplified for our case
* The pattern is `source // lower_factor % extent * scale`
*/
struct SplitExpr {
/*! \brief The source variable */
Var source;
/*! \brief The lower factor of the split expression */
int64_t lower_factor;
/*! \brief The extent of the split expression */
int64_t extent;
};

/*!
* \brief Collect the split expressions in the indexing pattern
* \param index The indexing pattern
* \param input_iters The input iterators' domain
* \param predicate The predicate of the affine map
* \param require_bijective Whether the affine map is required to be bijective
* \param analyzer The analyzer
* \return The collected split expressions
*/
static std::vector<SplitExpr> Collect(const PrimExpr& index,
const Map<Var, Range>& input_iters, //
const PrimExpr& predicate, //
bool require_bijective, //
arith::Analyzer* analyzer) {
DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
Array<arith::IterSumExpr> iter_sum_exprs = arith::DetectIterMap(
{analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer, diag_ctx);
if (iter_sum_exprs.empty()) {
return {};
}
ICHECK_EQ(iter_sum_exprs.size(), 1);
if (iter_sum_exprs[0]->args.size() == 0) {
return {};
}
SplitExprCollector collector;
collector.Visit(iter_sum_exprs[0]);
if (collector.failed_) {
return {};
}
return std::move(collector.exprs_);
}

private:
void Visit(const arith::IterSplitExpr& expr) {
if (const auto* var = expr->source->source.as<tir::VarNode>()) {
const int64_t* lower_factor = as_const_int(expr->lower_factor);
const int64_t* extent = as_const_int(expr->extent);
if (lower_factor == nullptr || extent == nullptr) {
failed_ = true;
return;
}
exprs_.push_back(SplitExpr{GetRef<Var>(var), *lower_factor, *extent});
} else if (const auto* iter_sum_expr = expr->source->source.as<arith::IterSumExprNode>()) {
Visit(GetRef<arith::IterSumExpr>(iter_sum_expr));
} else {
ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey();
}
}

void Visit(const arith::IterSumExpr& expr) {
for (const arith::IterSplitExpr& arg : expr->args) {
Visit(arg);
}
}

/*! \brief Whether the analysis failed */
bool failed_ = false;
/*! \brief The collected split expressions */
std::vector<SplitExpr> exprs_;
};

Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>& indices,
const Array<For>& loops, const PrimExpr& predicate,
arith::Analyzer* analyzer) {
int ndim = buffer->shape.size();
int n_loops = loops.size();
// Step 1. Collect the domains and indices of loop variables
Map<Var, Range> input_iters;
std::unordered_map<const VarNode*, int> var2id;
var2id.reserve(n_loops);
for (int i = 0; i < n_loops; ++i) {
const For& loop = loops[i];
input_iters.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
var2id.emplace(loop->loop_var.get(), i);
}
// Step 2. Calculate a functor that flattens a multi-dimensional index
auto f_flatten_index = [ndim, strides = GetStrides(buffer), dtype = buffer->DefaultIndexType()](
const Array<PrimExpr>& indices) -> PrimExpr {
PrimExpr flatten_index = make_const(dtype, 0);
for (int i = 0; i < ndim; ++i) {
flatten_index = flatten_index + strides[i] * indices[i];
}
return flatten_index;
};
// Step 3. Detect the IterSplitExpr of the indexing pattern
std::vector<SplitExprCollector::SplitExpr> split_exprs = SplitExprCollector::Collect(
/*index=*/f_flatten_index(indices), input_iters, predicate,
/*require_bijective=*/false, analyzer);
if (split_exprs.empty()) {
return NullOpt;
}
// Step 4. Sort the order of the split expressions
std::vector<int> order(split_exprs.size(), 0);
std::generate(order.begin(), order.end(), [n = 0]() mutable { return n++; });
std::sort(order.begin(), order.end(), [&split_exprs, &var2id](int _a, int _b) -> bool {
const SplitExprCollector::SplitExpr& a = split_exprs[_a];
const SplitExprCollector::SplitExpr& b = split_exprs[_b];
int a_var_id = var2id.at(a.source.get());
int b_var_id = var2id.at(b.source.get());
if (a_var_id != b_var_id) {
return a_var_id < b_var_id;
}
return a.lower_factor > b.lower_factor;
});
// Step 5. Create the indexing mapping
auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), //
split_exprs = std::move(split_exprs), //
order = std::move(order), //
shape = buffer->shape, //
analyzer //
](Array<Var> indices) -> Array<PrimExpr> {
ICHECK_EQ(indices.size(), shape.size());
for (int i = 0, n = indices.size(); i < n; ++i) {
analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i]));
}
PrimExpr index = f_flatten_index({indices.begin(), indices.end()});
int ndim = split_exprs.size();
// Step 5.1. Split the flattened index according to `split_exprs`
std::vector<PrimExpr> split;
split.reserve(ndim);
for (int i = ndim - 1; i >= 0; --i) {
index = analyzer->Simplify(index);
int64_t extent = split_exprs[i].extent;
split.push_back(analyzer->Simplify(floormod(index, extent)));
index = floordiv(index, extent);
}
std::reverse(split.begin(), split.end());
// Step 5.2. Reorder the indexing pattern according to `order`
Array<PrimExpr> results;
results.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
results.push_back(split[order[i]]);
}
return results;
};
return IndexMap::FromFunc(ndim, f_alter_layout);
}

TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap")
.set_body_typed([](Buffer buffer, Array<PrimExpr> indices, Array<For> loops,
PrimExpr predicate) {
arith::Analyzer analyzer;
return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer);
});

} // namespace tir
} // namespace tvm
Loading

0 comments on commit 60a6db2

Please sign in to comment.