[TIR] Utility function to decide loop mapping for auto tensorization #11050

Merged on Apr 20, 2022 (19 commits)
33 changes: 32 additions & 1 deletion python/tvm/tir/schedule/analysis.py
@@ -17,12 +17,16 @@
"""Analysis used in TensorIR scheduling"""
from typing import List, Optional

import tvm._ffi
from tvm.runtime import Object

from ..buffer import Buffer
from ..stmt import For
from ..expr import PrimExpr
-from ..function import IndexMap
+from ..function import IndexMap, PrimFunc

from . import _ffi_api
from .schedule import Schedule, BlockRV


def suggest_index_map(
@@ -56,3 +60,30 @@ def suggest_index_map(
loops,
predicate,
)


@tvm._ffi.register_object("tir.schedule.TensorizeInfo")
class TensorizeInfo(Object):
"""Necessary information used for tensorization."""


def get_tensorize_loop_mapping(
sch: Schedule, block: BlockRV, desc_func: PrimFunc
) -> Optional[TensorizeInfo]:
"""Establish a mapping between loops in a target block and an intrinsic description

Parameters
----------
sch : Schedule
The schedule to be tensorized
block : BlockRV
The target block to match against
desc_func : PrimFunc
The prim func describing the computation to be tensorized

Returns
-------
tensorize_info : Optional[TensorizeInfo]
TensorizeInfo structure if a valid mapping is found, None otherwise
"""
return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func) # type: ignore
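A minimal usage sketch of the new API (the 16x16x16 description and the TE matmul workload below are illustrative stand-ins, not code from this PR):

```python
import tvm
from tvm import te, tir
from tvm.script import tir as T
from tvm.tir.schedule.analysis import get_tensorize_loop_mapping

# Hypothetical intrinsic description: a 16x16x16 matmul block.
@T.prim_func
def matmul_desc_16x16x16(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    with T.block("root"):
        for i, j, k in T.grid(16, 16, 16):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

# A larger matmul workload whose loops we want to map onto the description.
k = te.reduce_axis((0, 128), "k")
A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
C = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

sch = tir.Schedule(te.create_prim_func([A, B, C]))
info = get_tensorize_loop_mapping(sch, sch.get_block("C"), matmul_desc_16x16x16)
if info is not None:
    print(info.loop_map)  # sref of each matched target loop -> desc loop
```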
12 changes: 12 additions & 0 deletions python/tvm/tir/stmt_functor.py
@@ -58,6 +58,18 @@ def post_order_visit(stmt, fvisit):
return _ffi_api.PostOrderVisit(stmt, fvisit) # type: ignore


def pre_order_visit(stmt, fvisit):
    """Recursively visit the stmt AST in pre-order, applying fvisit to each node.
    If fvisit returns False, the children of that node will not be visited.

    Parameters
    ----------
    stmt : Stmt
        The input statement to be visited.
    fvisit : function of the signature Object -> bool
        The visitor function.
    """
    return _ffi_api.PreOrderVisit(stmt, fvisit)  # type: ignore
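A minimal sketch of the early-exit behavior (the TVMScript function below is illustrative only): pre_order_visit reaches a parent before its children, and returning False prunes the subtree under that node.

```python
import tvm
from tvm import tir
from tvm.script import tir as T
from tvm.tir.stmt_functor import pre_order_visit

@T.prim_func
def fill(a: T.handle) -> None:
    A = T.match_buffer(a, (4, 4), "float32")
    for i, j in T.grid(4, 4):
        A[i, j] = 1.0

visited = []

def fvisit(node):
    visited.append(type(node).__name__)
    # Returning False on a loop prevents visiting its body.
    return not isinstance(node, tir.For)

pre_order_visit(fill.body, fvisit)
# The outer For is recorded, but the inner loop and the buffer store
# are skipped because fvisit returned False on the outer loop.
print(visited)
```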


def substitute(node, vmap):
"""Substitute the var specified by vmap.

4 changes: 4 additions & 0 deletions src/tir/ir/stmt_functor.cc
@@ -792,6 +792,10 @@ TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); });
});

TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n); });
});

TVM_REGISTER_GLOBAL("tir.Substitute")
.set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef {
if (node->IsInstance<StmtNode>()) {
33 changes: 33 additions & 0 deletions src/tir/schedule/analysis.h
@@ -656,6 +656,39 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const PrimExpr& predicate,
const StmtSRef& dom_high_exclusive,
arith::Analyzer* analyzer);

/*! \brief Necessary information used for tensorization */
class TensorizeInfoNode : public Object {
public:
/*! \brief Maps loops in a target block to the ones in an intrinsic description */
Map<tir::StmtSRef, tir::For> loop_map;
/*! \brief Maps each loop in an intrinsic description to its index, outer to inner */
Map<tir::For, Integer> desc_loop_indexer;

void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_map", &loop_map);
v->Visit("desc_loop_indexer", &desc_loop_indexer);
}

static constexpr const char* _type_key = "tir.schedule.TensorizeInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
};

class TensorizeInfo : public ObjectRef {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
};

/*!
* \brief Establish a mapping between loops in a target block and an intrinsic description
* \param self The schedule state to be tensorized
* \param block_sref The target block to match against
* \param desc_func The prim func describing the computation to be tensorized
* \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
*/
Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func);

} // namespace tir
} // namespace tvm

168 changes: 166 additions & 2 deletions src/tir/schedule/analysis/analysis.cc
@@ -16,6 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/container/optional.h>
#include <tvm/tir/expr.h>

#include "../utils.h"

namespace tvm {
@@ -492,8 +495,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
}
}

-std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+std::vector<IterVarType> GetBlockVarTypes(const BlockNode* block) {
std::vector<IterVarType> results;
results.reserve(block->iter_vars.size());
for (const IterVar& iter_var : block->iter_vars) {
@@ -502,6 +504,11 @@ std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
return results;
}

std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
return GetBlockVarTypes(block);
}

bool IsWriteCache(const StmtSRef& block_sref) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
if (block->writes.size() != 1) {
@@ -2028,5 +2035,162 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,  //
}
}

TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);

Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func) {
arith::Analyzer analyzer;
const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
// Step 1. Analyze desc_func, extract its block, loops and loop vars
const tir::BlockRealizeNode* desc_block = nullptr;
std::vector<const tir::ForNode*> desc_loops;
std::unordered_set<const tir::VarNode*> desc_loop_vars;
const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
ICHECK(desc_scope_realize);
{
auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
&analyzer](const ObjectRef& obj) -> bool {
// Extract the block
if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
desc_block = block;
return false;
}
// Extract loops
if (const auto* loop = obj.as<tir::ForNode>()) {
desc_loops.push_back(loop);
desc_loop_vars.insert(loop->loop_var.get());
if (!analyzer.CanProve(loop->min == 0)) {
return false;
}
}
return true;
};
tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
std::reverse(desc_loops.begin(), desc_loops.end());
ICHECK(desc_block);
}
// Step 2. Collect loops from block_sref
const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
std::vector<const tir::ForNode*> block_loops;
std::unordered_set<const tir::VarNode*> block_loop_vars;
{
for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
const auto* loop = loop_sref->StmtAs<tir::ForNode>();
if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
break;
}
block_loops.push_back(loop);
block_loop_vars.insert(loop->loop_var.get());
if (!analyzer.CanProve(loop->min == 0)) {
return NullOpt;
}
}
std::reverse(block_loops.begin(), block_loops.end());
}
// Step 3. Map from block loops to desc block loops
ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
const int n_block_vars = block->iter_values.size();
const int n_desc_vars = desc_block->iter_values.size();
const int offset = n_block_vars - n_desc_vars;

if (offset < 0) {
return NullOpt;
}

const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref);
const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get());

ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
ICHECK(block_loops.size() == iter_types_block.size());

// We assume that the orders of iter_vars in the target and the desc block are consistent.
masahi (Member, Author) commented on Apr 19, 2022:

i.e., no matter what the permutation of the loops is, we should always have

i, j, k = T.axis.remap("SSR", [i0, i1, i2])

for GEMM.

I think this is a reasonable assumption. Correct me if I'm wrong @spectrometerHBH @junrushao1994 @vinx13

Reply (Member):

I agree this is a reasonable assumption. Though there might be corner cases, it covers all of the current use cases.
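For illustration, a minimal TVMScript sketch of this assumption (the permuted_matmul below is a hypothetical example, not code from this PR): even with the reduction loop outermost, the block still remaps its iterators in the canonical i, j, k order.

```python
import tvm
from tvm.script import tir as T

@T.prim_func
def permuted_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    # The reduction loop i2 (k) is outermost, yet the remap order stays
    # "SSR" over [i0, i1, i2], consistent with the intrinsic description.
    for i2, i0, i1 in T.grid(128, 128, 128):
        with T.block("update"):
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            with T.init():
                C[i, j] = T.float32(0)
            C[i, j] = C[i, j] + A[i, k] * B[k, j]
```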

// Based on that assumption, the following logic supports arbitrary permutations of a loop order,
// such as
//
//   for k:
//     for i:
//       for j:
//         C[i, j] += A[i, k] * B[k, j]
//
// or
//
//   for i:
//     for j:
//       for k:
//         C[i, j] += A[i, k] * B[k, j]

int next_block_ind = block_loops.size() - 1;
for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
// Step 3.1. Find the corresponding loop of the i_desc-th block var of desc
const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
const tir::ForNode* desc_loop = nullptr;
IterVarType iter_type_desc = iter_types_desc[i_desc];
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
// Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
if (!UsesVar(r,
[&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) {
desc_loop = desc_loops[i];
iter_type_desc = iter_types_desc[i];
break;
}
}
if (desc_loop == nullptr || desc_loop->extent.as<IntImmNode>() == nullptr) {
return NullOpt;
}

const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();

// Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type
PrimExpr block_bind;
for (int i = next_block_ind; i >= 0; --i) {
if (iter_types_block[i] == iter_type_desc) {
next_block_ind = i - 1;
block_bind = block->iter_values[i];
break;
}
}

if (!block_bind.defined()) return NullOpt;

// Step 3.3. Find the corresponding loop of the target block
for (int i = 0, n = block_loops.size(); i < n; ++i) {
// Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
if (!UsesVar(r,
[&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) {
const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();

// Check divisibility
if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
return NullOpt;
}

const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loops[i]];
auto it = ret->loop_map.find(block_loop_sref);
if (it == ret->loop_map.end()) {
ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
} else if ((*it).second.get() != desc_loop) {
return NullOpt;
}

break;
}
}
}
masahi (Member, Author) commented:

The logic here is very different from the one in the original code (https://github.com/spectrometerHBH/tvm/blob/auto-tensorization/src/tir/schedule/analysis/analysis.cc#L1246). I was not able to understand why the original code was written that way, and it did not work for the case where the matching loops in the target block are not in the innermost positions (conv2d NCHWc on CPU, exercised by the test test_get_tensorize_loop_mapping_conv2d_nchwc_vnni).

I think my change is simple and obvious. The conditions for a match are (1) divisibility of the loop extent and (2) matching iterator types (reduction vs spatial). The mapping is determined starting from the innermost axis.

Please have a look at this change carefully, and let me know if I need to bring back some of the logic in the original code @spectrometerHBH @vinx13

Reply (Member):

I would love to have @spectrometerHBH review this change before merging.

Reply (Contributor):

The goal of the original mapping is to support

for k:
  for i:
    for j:
      C[i, j] += A[i, k] * B[k, j]

where the loops are not in the same order as in the tensor intrinsic description function.

spectrometerHBH (Contributor) commented on Apr 19, 2022:

But it also makes sense if we don't support such cases in this PR, so I approve it.

masahi (Member, Author) commented:

Thanks @spectrometerHBH, I now understand the original code and was able to integrate the original logic to support loop permutations. Please have a look at the current diff, also cc @vinx13 @Hzfengsy @MasterJH5574.

The key difference between the original code and the code I submitted yesterday is that my code looked only at the loop nest (ForNode) to determine the mapping, while @spectrometerHBH's mapping logic is based on the iter_vars/values of the block (and is therefore invariant to the order of the loop nest).
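As a rough mental model of the two match conditions described above, here is a minimal Python sketch (an illustration only, not the TVM implementation; all names are hypothetical): axes are compared innermost-first, iterator types must agree, and the block extent must be divisible by the desc extent.

```python
def loops_can_map(block_extents, block_types, desc_extents, desc_types):
    # Walk both loop nests from the innermost axis outward.
    pairs = zip(
        zip(reversed(block_extents), reversed(block_types)),
        zip(reversed(desc_extents), reversed(desc_types)),
    )
    for (b_extent, b_type), (d_extent, d_type) in pairs:
        if b_type != d_type:          # (2) iterator types must match (S vs R)
            return False
        if b_extent % d_extent != 0:  # (1) extents must divide evenly
            return False
    return True

# A 1024x1024x1024 GEMM ("SSR") against a 16x16x4 VNNI-like description:
assert loops_can_map([1024, 1024, 1024], "SSR", [16, 16, 4], "SSR")
# A reduction extent of 6 is not divisible by 4, so no mapping exists:
assert not loops_can_map([1024, 1024, 6], "SSR", [16, 16, 4], "SSR")
```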


for (int i = 0, n = desc_loops.size(); i < n; ++i) {
ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
}
return TensorizeInfo(ret);
}

TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
.set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
});

} // namespace tir
} // namespace tvm