forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MetaSchedule] Post Processor: Rewrite Reduction Block (apache#10013)
Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Xiyou Zhou <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Xiyou Zhou <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]>
- Loading branch information
Showing
6 changed files
with
454 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
31 changes: 31 additions & 0 deletions
31
python/tvm/meta_schedule/postproc/rewrite_reduction_block.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""A postprocessor that rewrites reduction block by moving the init block out.""" | ||
|
||
from tvm._ffi.registry import register_object | ||
from .. import _ffi_api | ||
from .postproc import Postproc | ||
|
||
|
||
@register_object("meta_schedule.RewriteReductionBlock") | ||
class RewriteReductionBlock(Postproc): | ||
"""A postprocessor that rewrites reduction block by moving the init block out.""" | ||
|
||
def __init__(self) -> None: | ||
self.__init_handle_by_constructor__( | ||
_ffi_api.PostprocRewriteReductionBlock, # type: ignore # pylint: disable=no-member | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#include "../utils.h" | ||
|
||
namespace tvm { | ||
namespace tir { | ||
|
||
/*! \brief The visitor that finds all the reduction block to be decomposed */ | ||
struct ReductionBlockFinder : private StmtVisitor { | ||
public: | ||
/*! \brief Find all the reduction blocks that should be decomposed */ | ||
static std::vector<std::pair<StmtSRef, String>> Find(const ScheduleState& self) { | ||
std::vector<std::pair<StmtSRef, String>> results; | ||
for (const auto& kv : self->mod->functions) { | ||
GlobalVar g_var = kv.first; | ||
BaseFunc base_func = kv.second; | ||
if (const auto* prim_func = base_func.as<PrimFuncNode>()) { | ||
ReductionBlockFinder finder; | ||
finder(prim_func->body); | ||
for (const BlockNode* block : finder.results_) { | ||
results.emplace_back(self->stmt2ref.at(block), g_var->name_hint); | ||
} | ||
} | ||
} | ||
return results; | ||
} | ||
|
||
private: | ||
void VisitStmt_(const ForNode* loop) final { | ||
runtime::ThreadScope thread_scope = GetThreadScope(loop); | ||
if (IsThreadIdx(thread_scope) || IsBlockIdx(thread_scope)) { | ||
thread_bound_loop_vars_.insert(loop->loop_var.get()); | ||
} | ||
StmtVisitor::VisitStmt_(loop); | ||
} | ||
|
||
void VisitStmt_(const BlockRealizeNode* realize) final { | ||
if (realize->block->init.defined() && AllReductionIterVarAreUnbound(realize)) { | ||
results_.push_back(realize->block.get()); | ||
} | ||
StmtVisitor::VisitStmt_(realize); | ||
} | ||
|
||
bool AllReductionIterVarAreUnbound(const BlockRealizeNode* realize) const { | ||
if (thread_bound_loop_vars_.empty()) { | ||
return true; | ||
} | ||
auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); }; | ||
const BlockNode* block = realize->block.get(); | ||
ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); | ||
int n = block->iter_vars.size(); | ||
for (int i = 0; i < n; ++i) { | ||
IterVar iter_var = block->iter_vars[i]; | ||
PrimExpr binding = realize->iter_values[i]; | ||
if (iter_var->iter_type == tir::kCommReduce) { | ||
if (UsesVar(binding, f_find)) { | ||
return false; | ||
} | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
/*! \brief The results of the collection */ | ||
std::vector<const BlockNode*> results_; | ||
/*! \brief Loop variables that are bound to threads */ | ||
std::unordered_set<const VarNode*> thread_bound_loop_vars_; | ||
}; | ||
|
||
/*! | ||
* \brief Find the innermost loop that the `init` of the input block could be decomposed to | ||
* \param block_sref The StmtSRef of the block to be decomposed | ||
* \return The index of the innermost loop where the `init` of the input block could be decomposed, | ||
* or -1 if the `init` does not need to be decomposed. | ||
*/ | ||
int FindDecomposePoint(const StmtSRef& block_sref) { | ||
Array<StmtSRef> loop_srefs = GetLoops(block_sref); | ||
int n = loop_srefs.size(); | ||
for (int i = 0; i < n; ++i) { | ||
if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) { | ||
return i; | ||
} | ||
} | ||
return -1; | ||
} | ||
|
||
} // namespace tir | ||
} // namespace tvm | ||
|
||
namespace tvm { | ||
namespace meta_schedule { | ||
|
||
/*! \brief Rewrite reduction block by moving the init block out */ | ||
class RewriteReductionBlockNode : public PostprocNode { | ||
public: | ||
// Inherited from PostprocNode | ||
void InitializeWithTuneContext(const TuneContext& context) final {} | ||
// Inherited from PostprocNode | ||
bool Apply(const tir::Schedule& sch) final; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) {} | ||
|
||
static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode); | ||
}; | ||
|
||
bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { | ||
for (;;) { | ||
std::vector<std::pair<tir::StmtSRef, String>> results = | ||
tir::ReductionBlockFinder::Find(sch->state()); | ||
int rewritten = 0; | ||
for (const auto& kv : results) { | ||
const tir::StmtSRef& block_sref = kv.first; | ||
const String& global_var_name = kv.second; | ||
int decompose_point = tir::FindDecomposePoint(block_sref); | ||
if (decompose_point == -1) { | ||
continue; | ||
} | ||
tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); | ||
Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv); | ||
tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); | ||
++rewritten; | ||
} | ||
if (rewritten == 0) { | ||
break; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
Postproc Postproc::RewriteReductionBlock() { | ||
ObjectPtr<RewriteReductionBlockNode> n = make_object<RewriteReductionBlockNode>(); | ||
return Postproc(n); | ||
} | ||
|
||
TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode); | ||
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") | ||
.set_body_typed(Postproc::RewriteReductionBlock); | ||
|
||
} // namespace meta_schedule | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.