Skip to content

Commit

Permalink
[MetaSchedule] Schedule Rule: Cross Thread Reduction (#9994)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
  • Loading branch information
7 people authored Jan 21, 2022
1 parent 81b66e6 commit 1ac01b4
Show file tree
Hide file tree
Showing 9 changed files with 605 additions and 4 deletions.
7 changes: 7 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ class ScheduleRule : public runtime::ObjectRef {
*/
TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, //
Optional<Integer> max_innermost_factor);
/*!
 * \brief Create a schedule rule which applies cross-thread reduction to the reduction blocks
 * that need it, i.e. splits the (fused) reduction loop of such a block and binds the inner part
 * to "threadIdx.x"
 * \param thread_extents Candidates of thread axis extent, from which the extent of "threadIdx.x"
 * is sampled. All the values are required to be positive.
 * \return The schedule rule created
 */
TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
/*!
* \brief A rule that randomly select a compute-at location for a free block
* \return The rule created
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
"""
from .add_rfactor import AddRFactor
from .auto_inline import AutoInline
from .cross_thread_reduction import CrossThreadReduction
from .schedule_rule import PyScheduleRule, ScheduleRule
from .random_compute_location import RandomComputeLocation
41 changes: 41 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Rules which apply cross-thread reduction to some reduction blocks correspondingly when needed"""
from typing import List

from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


@register_object("meta_schedule.CrossThreadReduction")
class CrossThreadReduction(ScheduleRule):
    """A schedule rule which applies cross-thread reduction to the reduction blocks that need it.

    The rule itself is implemented on the C++ side; this class only constructs it through the
    FFI function `ScheduleRuleCrossThreadReduction`.

    Parameters
    ----------
    thread_extents: List[int]
        Candidates of thread axis extent (values are required to be positive).
    """

    def __init__(self, thread_extents: List[int]) -> None:
        # Delegate the construction (and the validation of `thread_extents`) to the FFI
        # constructor.
        self.__init_handle_by_constructor__(
            _ffi_api.ScheduleRuleCrossThreadReduction,  # type: ignore # pylint: disable=no-member
            thread_extents,
        )
8 changes: 8 additions & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tvm.meta_schedule.schedule_rule import (
AddRFactor,
AutoInline,
CrossThreadReduction,
ScheduleRule,
)
from tvm.target import Target
Expand Down Expand Up @@ -53,3 +54,10 @@ def add_rfactor(target: Target) -> ScheduleRule:
if target.kind.name == "llvm":
return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)
raise NotImplementedError(f"{target.kind.name} is not supported")


def cross_thread_reduction(target: Target) -> ScheduleRule:
    """Default schedule rule for cross-thread reduction.

    Parameters
    ----------
    target : Target
        The target to create the rule for. Only the "cuda" target kind is supported.

    Returns
    -------
    rule : ScheduleRule
        A `CrossThreadReduction` rule with the default thread extent candidates.

    Raises
    ------
    NotImplementedError
        If the kind of `target` is not "cuda".
    """
    # Guard clause: reject every target kind other than CUDA up front.
    if target.kind.name != "cuda":
        raise NotImplementedError(f"{target.kind.name} is not supported")
    return CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
285 changes: 285 additions & 0 deletions src/meta_schedule/schedule_rule/cross_thread_reduction.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

class CrossThreadReductionNode : public ScheduleRuleNode {
public:
// Inherited from ScheduleRuleNode
void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(context->target.defined());
Target target = context->target.value();

Optional<Integer> opt_max_threads_per_block = target->GetAttr<Integer>("max_threads_per_block");
Optional<Integer> opt_warp_size = target->GetAttr<Integer>("thread_warp_size");

if (!opt_max_threads_per_block.defined()) {
LOG(WARNING) << "Target does not have attribute \"max_threads_per_block\", therefore the "
"rule CrossThreadReduction will not be applied";
}
if (!opt_warp_size.defined()) {
LOG(WARNING) << "Target does not have attribute \"thread_warp_size\", therefore the rule "
"CrossThreadReduction will not be applied";
}
max_threads_per_block = opt_max_threads_per_block.value_or(Integer(-1))->value;
warp_size = opt_warp_size.value_or(Integer(-1))->value;
}

// Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
// Step 0. Check the conditions of this rule.
if (max_threads_per_block == -1 || warp_size == -1) {
return {sch};
}
const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_threads_per_block,
warp_size)) {
return {sch};
}

// Step 1. Make a copy of the original schedule. The new copy is used for scheduling.
tir::Schedule tmp_sch = sch->Copy();
tmp_sch->Seed(sch->ForkSeed());

// Step 2. Check the opportunity for block fusion. We say "fusible", if we can compute-at the
// block to its consumers. We want to fuse as much as possible because it results in
// significantly faster schedule.
bool fusible = false;
// `target_loop` is the loop position where the input block will be computed at.
tir::LoopRV target_loop{nullptr};
// `target_block` is the consumer block that we want to compute-at the input block to.
tir::BlockRV target_block{nullptr};
// `tgt_block_innermost_loop` is the innermost loop outside the target block.
tir::LoopRV tgt_block_innermost_loop{nullptr};

std::tie(fusible, target_loop, target_block, tgt_block_innermost_loop) =
GetComputeTargetLoopAndBlock(tmp_sch, block_rv);

// Step 3. Try block fusion.
int n_candidate = static_cast<int>(thread_extents.size());
Array<FloatImm> probs(n_candidate, FloatImm(DataType::Float(64), 1.0 / n_candidate));
tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs);
if (fusible) {
ICHECK(target_block.defined());
ICHECK(target_loop.defined());

// Step 3.1.
// - If the outer loops of `target_block` haven't been bound to "threadIdx.x", we should first
// bound the innermost outer loop of `target_block` to threadIdx. Possibly we need to split
// the loop before binding.
// - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor.
if (!InThreadScope(tmp_sch, target_block)) {
const Array<tir::LoopRV>& split_res =
tmp_sch->Split(tgt_block_innermost_loop, {NullOpt, thread_extent});
tmp_sch->Bind(split_res[1], "threadIdx.x");
if (tgt_block_innermost_loop.same_as(target_loop)) {
target_loop = split_res[0];
}
} else {
thread_extent = GetThreadIdxExtentFromTrace(tmp_sch->trace().value());
}
// Step 3.2. Do the compute-at.
tmp_sch->ComputeAt(block_rv, target_loop, /*preserve_unit_loops=*/true);
// Step 3.3. Set the storage scope of the output buffer to shared memory.
tmp_sch->SetScope(block_rv, /*buffer_index=*/0, /*storage_scope=*/"shared");
}

// Step 4. Reorder the loop axes if reduction loops are not innermost. After the reordering,
// fuse all the reduction loops.
size_t num_spatial_loops;
tir::LoopRV fused_reduce_loop;
ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
// Step 5. Split the fused reduction loop and bind the inner one to threadIdx.
const Array<tir::LoopRV>& split_res =
tmp_sch->Split(fused_reduce_loop, {NullOpt, thread_extent});
tmp_sch->Bind(split_res[1], "threadIdx.x");

return {tmp_sch, sch};
}

private:
/*!
* \brief Check whether the input block is in thread scope, i.e., some of its outer loop is
* bound to threadIdx.
* \param sch The TensorIR schedule
* \param block The block to be checked
* \return A boolean indicating whether the block is in thread scope.
*/
bool InThreadScope(const tir::Schedule& sch, const tir::BlockRV& block) {
const Array<tir::LoopRV>& axes = sch->GetLoops(block);
for (const tir::LoopRV& loop_rv : axes) {
const tir::For& loop = sch->Get(loop_rv);
runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get());
if (tir::IsThreadIdx(thread_scope)) {
return true;
}
}
return false;
}

/*!
* \brief Get the ExprRV which used to define the extent of a given loop.
* \param trace The trace of the schedule, where the extent is to be found
* \param loop The loop whose extent is to be found
* \param extent The finding result
* \return Whether the find is successful.
*/
bool GetLoopRVExtentSource(const tir::Trace& trace, const tir::LoopRV& loop,
tir::ExprRV* extent) {
for (const tir::Instruction& inst : trace->insts) {
if (inst->kind->name == "Split") {
int i = std::find(inst->outputs.begin(), inst->outputs.end(), loop) - inst->outputs.begin();
CHECK(inst->inputs[1 + i].defined())
<< "ValueError: Extracting an extent which needs inference is not supported so far";
*extent = Downcast<tir::ExprRV>(inst->inputs[1 + i]);
return true;
}
}
return false;
}

/*!
* \brief Get the ExprRV extent of "threadIdx.x" in the given schedule trace.
* \param trace The trace of the schedule, where the extent is to be found
* \return The extent of "threadIdx.x" in the input schedule
*/
tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) {
tir::ExprRV extent{nullptr};
for (const tir::Instruction& inst : trace->insts) {
if (inst->kind->name == "Bind" && Downcast<String>(inst->attrs[0]) == "threadIdx.x") {
if (GetLoopRVExtentSource(trace, Downcast<tir::LoopRV>(inst->inputs[0]), &extent)) {
return extent;
}
}
}
CHECK(false) << "ValueError: Unable to get the extent of \"threadIdx.x\"";
throw;
}

/*!
* \brief Get the compute-at target loop and the first block under the target loop.
* \param sch The TensorIR schedule
* \param block_rv The block whose compute-at target loop is queried
* \return A tuple consisting of
* 1. a boolean indicating whether the block can be computed at some target loop (a.k.a. fusible);
* 2. the compute-at target loop when fusible, or a null loop random variable;
* 3. the first block under the target loop when fusible, or a null block random variable;
* 4. the innermost loop outside the target block when fusible, or a null block random variable.
*/
std::tuple<bool, tir::LoopRV, tir::BlockRV, tir::LoopRV> GetComputeTargetLoopAndBlock(
const tir::Schedule& sch, const tir::BlockRV& block_rv) {
// Step 1. Get all the consumers of the input block.
Array<tir::BlockRV> consumers = sch->GetConsumers(block_rv);

// Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is
// not fusible.
if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) {
return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
tir::LoopRV{nullptr});
}

// Step 3. Calculate the lowest common ancestor of all the consumers.
// - If the lowest common ancestor is a block:
// - if there is only one consumer, the target block is that consumer;
// - if there are multiple consumers, they must not share a common loop, and the case is not
// fusible;
// - If the lowest common ancestor is a loop, the target block is also the first consumer.
const tir::StmtSRef& lca_sref =
tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers));
if (consumers.size() > 1 && lca_sref->StmtAs<tir::BlockNode>() != nullptr) {
return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
tir::LoopRV{nullptr});
}

// Step 4. Get the outer loops of the target block, and get the compute-at position index.
Array<tir::LoopRV> tgt_block_loops = sch->GetLoops(consumers[0]);
int pos = GetComputePosition(sch, sch->GetLoops(block_rv), tgt_block_loops, lca_sref);

// Step 5. A negative position index means not fusible, and vice-versa.
if (pos < 0) {
return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
tir::LoopRV{nullptr});
} else {
return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back());
}
}

/*!
* \brief Get the compute-at position index of the input block, according to
* 1. the loops outside the input block;
* 2. the loops outside the target block;
* 3. the lowest common ancestor of all the consumers of the input block.
* \param sch The TensorIR schedule
* \param block_loops The loops outside the input block
* \param tgt_block_loops The loops outside the target block
* \param lca_sref The lowest common ancestor of all the consumers of the input block
* \return The compute-at position index of the input block
*/
int GetComputePosition(const tir::Schedule& sch, const Array<tir::LoopRV>& block_loops,
const Array<tir::LoopRV>& tgt_block_loops, const tir::StmtSRef& lca_sref) {
int n_block_loop = static_cast<int>(block_loops.size());
int n_tgt_block_loop = static_cast<int>(tgt_block_loops.size());

for (int i = 0; i < n_block_loop && i < n_tgt_block_loop; ++i) {
if (tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tir::IterVarType::kDataPar) {
return i - 1;
} else if (sch->GetSRef(tgt_block_loops[i]).same_as(lca_sref)) {
// If the lowest common ancestor is a loop, the compute location of the input block should
// not be deeper than the LCA loop.
return i;
}
}
return std::min(n_block_loop, n_tgt_block_loop) - 1;
}

public:
/*! \brief The maximum number of threads allowed in a thread block */
int max_threads_per_block;
/*! \brief The number of threads per warp */
int warp_size;
/*! \brief Candidates of thread axis extent (values are required to be positive). */
Array<Integer> thread_extents;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("max_threads_per_block", &max_threads_per_block);
v->Visit("warp_size", &warp_size);
v->Visit("thread_extents", &thread_extents);
}

static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction";
TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode);
};

ScheduleRule ScheduleRule::CrossThreadReduction(Array<Integer> thread_extents) {
  // Validate every candidate before building the node: a non-positive extent is rejected.
  const int n_extents = static_cast<int>(thread_extents.size());
  for (int idx = 0; idx < n_extents; ++idx) {
    const Integer extent = thread_extents[idx];
    CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive";
  }
  // The target-dependent fields of the node are filled in later by `InitializeWithTuneContext`.
  ObjectPtr<CrossThreadReductionNode> node = make_object<CrossThreadReductionNode>();
  node->thread_extents = std::move(thread_extents);
  return ScheduleRule(node);
}

// Register the node type, and expose the factory function under the global name
// "meta_schedule.ScheduleRuleCrossThreadReduction", which is the name the Python class
// `CrossThreadReduction` looks up through `_ffi_api` to construct the rule.
TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction")
    .set_body_typed(ScheduleRule::CrossThreadReduction);

} // namespace meta_schedule
} // namespace tvm
5 changes: 3 additions & 2 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1788,8 +1788,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, //
/*require_stage_pipeline=*/false, //
/*require_subtree_compact_dataflow=*/false);
if (!(IsReductionBlock(self, block_sref, scope_sref) && //
IsTrivialBinding(self, block_sref))) {
if (!IsReductionBlock(self, block_sref, scope_sref) //
|| !IsTrivialBinding(self, block_sref) //
|| HasBeenMultiLevelTiled(block_sref)) {
return false;
}

Expand Down
6 changes: 4 additions & 2 deletions src/tir/schedule/trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,10 @@ Trace TraceNode::Simplified(bool remove_postproc) const {
}
// Add its inputs as "used" ones
for (const ObjectRef& obj : inst->inputs) {
if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() ||
obj->IsInstance<VarNode>()) {
if (!obj.defined()) {
continue;
} else if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() ||
obj->IsInstance<VarNode>()) {
used_rvs.insert(obj.get());
continue;
} else if (obj->IsInstance<PrimExprNode>()) {
Expand Down
Loading

0 comments on commit 1ac01b4

Please sign in to comment.