Skip to content

Commit

Permalink
[MetaSchedule] Post Processor: Rewrite Reduction Block (apache#10013)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
  • Loading branch information
7 people authored and yuanfz98 committed Jan 24, 2022
1 parent c01994f commit d759a72
Show file tree
Hide file tree
Showing 6 changed files with 454 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
# under the License.
"""The tvm.meta_schedule.postproc package."""
from .postproc import Postproc, PyPostproc
from .rewrite_reduction_block import RewriteReductionBlock
from .verify_gpu_code import VerifyGPUCode
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/postproc/rewrite_reduction_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that rewrites reduction block by moving the init block out."""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.RewriteReductionBlock")
class RewriteReductionBlock(Postproc):
    """Postprocessor that rewrites reduction blocks by moving the init block out.

    The instance is backed by the handle produced by the
    ``PostprocRewriteReductionBlock`` FFI constructor; it takes no arguments.
    """

    def __init__(self) -> None:
        ctor = (
            _ffi_api.PostprocRewriteReductionBlock  # type: ignore # pylint: disable=no-member
        )
        self.__init_handle_by_constructor__(ctor)
157 changes: 157 additions & 0 deletions src/meta_schedule/postproc/rewrite_reduction_block.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

/*! \brief The visitor that finds all the reduction block to be decomposed */
struct ReductionBlockFinder : private StmtVisitor {
public:
/*! \brief Find all the reduction blocks that should be decomposed */
static std::vector<std::pair<StmtSRef, String>> Find(const ScheduleState& self) {
std::vector<std::pair<StmtSRef, String>> results;
for (const auto& kv : self->mod->functions) {
GlobalVar g_var = kv.first;
BaseFunc base_func = kv.second;
if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
ReductionBlockFinder finder;
finder(prim_func->body);
for (const BlockNode* block : finder.results_) {
results.emplace_back(self->stmt2ref.at(block), g_var->name_hint);
}
}
}
return results;
}

private:
void VisitStmt_(const ForNode* loop) final {
runtime::ThreadScope thread_scope = GetThreadScope(loop);
if (IsThreadIdx(thread_scope) || IsBlockIdx(thread_scope)) {
thread_bound_loop_vars_.insert(loop->loop_var.get());
}
StmtVisitor::VisitStmt_(loop);
}

void VisitStmt_(const BlockRealizeNode* realize) final {
if (realize->block->init.defined() && AllReductionIterVarAreUnbound(realize)) {
results_.push_back(realize->block.get());
}
StmtVisitor::VisitStmt_(realize);
}

bool AllReductionIterVarAreUnbound(const BlockRealizeNode* realize) const {
if (thread_bound_loop_vars_.empty()) {
return true;
}
auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); };
const BlockNode* block = realize->block.get();
ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
int n = block->iter_vars.size();
for (int i = 0; i < n; ++i) {
IterVar iter_var = block->iter_vars[i];
PrimExpr binding = realize->iter_values[i];
if (iter_var->iter_type == tir::kCommReduce) {
if (UsesVar(binding, f_find)) {
return false;
}
}
}
return true;
}

/*! \brief The results of the collection */
std::vector<const BlockNode*> results_;
/*! \brief Loop variables that are bound to threads */
std::unordered_set<const VarNode*> thread_bound_loop_vars_;
};

/*!
* \brief Find the innermost loop that the `init` of the input block could be decomposed to
* \param block_sref The StmtSRef of the block to be decomposed
* \return The index of the innermost loop where the `init` of the input block could be decomposed,
* or -1 if the `init` does not need to be decomposed.
*/
int FindDecomposePoint(const StmtSRef& block_sref) {
Array<StmtSRef> loop_srefs = GetLoops(block_sref);
int n = loop_srefs.size();
for (int i = 0; i < n; ++i) {
if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) {
return i;
}
}
return -1;
}

} // namespace tir
} // namespace tvm

namespace tvm {
namespace meta_schedule {

/*!
 * \brief A postprocessor that rewrites reduction blocks by moving the init block out.
 * \note The node carries no attributes and needs nothing from the tuning context.
 */
class RewriteReductionBlockNode : public PostprocNode {
 public:
  /*! \brief No attributes to visit */
  void VisitAttrs(tvm::AttrVisitor* v) {}

  /*! \brief Inherited from PostprocNode; intentionally a no-op */
  void InitializeWithTuneContext(const TuneContext& context) final {}

  /*! \brief Inherited from PostprocNode; defined out of line */
  bool Apply(const tir::Schedule& sch) final;

  static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock";
  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode);
};

/*!
 * \brief Decompose the `init` of every qualifying reduction block in the schedule
 * \param sch The schedule to be rewritten in place
 * \return Always true
 *
 * The schedule is re-scanned after every pass that decomposed at least one block, and the
 * loop stops at the first pass that decomposes nothing.
 */
bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) {
  bool changed = true;
  while (changed) {
    changed = false;
    const std::vector<std::pair<tir::StmtSRef, String>> candidates =
        tir::ReductionBlockFinder::Find(sch->state());
    for (const std::pair<tir::StmtSRef, String>& candidate : candidates) {
      const int loop_index = tir::FindDecomposePoint(candidate.first);
      if (loop_index != -1) {
        // Look the block up again through the schedule to obtain a random variable for it
        tir::BlockRV block_rv = GetRVFromSRef(sch, candidate.first, candidate.second);
        Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
        // The returned init block is not needed here
        sch->DecomposeReduction(block_rv, loop_rvs[loop_index]);
        changed = true;
      }
    }
  }
  return true;
}

/*!
 * \brief Create a postprocessor that rewrites reduction blocks by moving the init block out
 * \return The postprocessor wrapping a newly allocated RewriteReductionBlockNode
 */
Postproc Postproc::RewriteReductionBlock() {
  return Postproc(make_object<RewriteReductionBlockNode>());
}

// Register the node type, and register the factory as a global function under the name
// "meta_schedule.PostprocRewriteReductionBlock".
// NOTE(review): presumably this is the name the Python wrapper resolves through
// `_ffi_api.PostprocRewriteReductionBlock` -- confirm against the FFI module prefix.
TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock")
.set_body_typed(Postproc::RewriteReductionBlock);

} // namespace meta_schedule
} // namespace tvm
13 changes: 13 additions & 0 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,19 @@ inline std::string Concat(const Array<String>& strs, const std::string& delim) {
return os.str();
}

/*!
 * \brief Get the BlockRV from a block StmtSRef
 *
 * The block is looked up again through the schedule by its `name_hint` and the name of the
 * enclosing global function, rather than by the sref itself.
 * NOTE(review): this presumably requires the block name to identify a single block within
 * that function -- confirm against the behavior of `Schedule::GetBlock`.
 *
 * \param sch The schedule
 * \param block_sref The block StmtSRef
 * \param global_var_name The global variable name
 * \return The BlockRV
 */
inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref,
const String& global_var_name) {
// Resolve the sref to the BlockNode it refers to; the macro's expansion is defined elsewhere
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
return sch->GetBlock(block->name_hint, global_var_name);
}

/*!
* \brief A helper data structure that replays a trace and collects failure counts
* for each postprocessor
Expand Down
30 changes: 30 additions & 0 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,36 @@ inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_va
Var(std::move(name), loop->loop_var.dtype()), iter_var_type);
}

/*!
 * \brief Get the thread scope bound to the specific loop
 * \param loop The loop to be inspected
 * \return The thread scope created from the thread tag of the loop's thread binding, or
 *         `ThreadScope{-1, -1}` if the loop is not of kind `kThreadBinding`
 */
inline runtime::ThreadScope GetThreadScope(const ForNode* loop) {
  if (loop->kind != ForKind::kThreadBinding) {
    // Not a thread-binding loop: report the "no scope" sentinel
    return runtime::ThreadScope{-1, -1};
  }
  return runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag);
}

/*!
 * \brief Check if the thread scope is blockIdx
 * \param thread_scope The thread scope to be checked
 * \return True if the rank of the thread scope is 0, i.e. the thread scope is blockIdx
 */
inline bool IsBlockIdx(const runtime::ThreadScope& thread_scope) {
  // The rank of blockIdx is 0
  constexpr int kBlockIdxRank = 0;
  return thread_scope.rank == kBlockIdxRank;
}

/*!
 * \brief Check if the thread scope is threadIdx
 * \param thread_scope The thread scope to be checked
 * \return True if the rank of the thread scope is 1 and its dimension index is non-negative
 * \note NOTE(review): the `dim_index >= 0` condition presumably filters out rank-1 scopes
 *       that are not a concrete threadIdx dimension -- confirm against `ThreadScope::Create`.
 */
inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) {
  if (thread_scope.rank != 1) {
    return false;
  }
  return thread_scope.dim_index >= 0;
}

/******** Integer set ********/

/*!
Expand Down
Loading

0 comments on commit d759a72

Please sign in to comment.