Skip to content

Commit

Permalink
reserve old compute at iter domain behavior with allow_block_predicat…
Browse files Browse the repository at this point in the history
…e=False
  • Loading branch information
wrongtest-intellif committed Dec 16, 2021
1 parent f00e619 commit f91c1fc
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 52 deletions.
8 changes: 5 additions & 3 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,10 @@ class ScheduleNode : public runtime::Object {
* \param block_rv The block to be moved
* \param loop_rv The loop where the block to be moved under
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \param allow_block_predicate Whether to use block predicate for block iteration bounds
*/
virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) = 0;
virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
bool allow_block_predicate) = 0;
/*!
* \brief Move a consumer block under the specific loop, and regenerate the
* loops induced by the block so that the buffer region consumed by the consumer block could
Expand All @@ -372,9 +373,10 @@ class ScheduleNode : public runtime::Object {
* \param block_rv The block to be moved
* \param loop_rv The loop where the block to be moved under
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \param allow_block_predicate Whether to use block predicate for block iteration bounds
*/
virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) = 0;
bool preserve_unit_loops, bool allow_block_predicate) = 0;
/*!
* \brief Inline a block into its consumer(s). It requires:
* 1) The block is a complete non-root block, which only produces one buffer
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,7 @@ def compute_at(
block: BlockRV,
loop: LoopRV,
preserve_unit_loops: bool = False,
allow_block_predicate: bool = True,
) -> None:
"""Compute-At. Move a producer block under the specific loop, and regenerate the
loops induced by the block so that the buffer region produced by the producer block could
Expand Down Expand Up @@ -1040,6 +1041,9 @@ def compute_at(
preserve_unit_loops: bool
Whether to keep the trivial loops whose extents are 1
allow_block_predicate: bool
Whether to use block predicate for block iteration bounds
Examples
--------
Expand Down Expand Up @@ -1096,13 +1100,15 @@ def after_compute_at(a: T.handle, c: T.handle) -> None:
block,
loop,
preserve_unit_loops,
allow_block_predicate,
)

def reverse_compute_at(
self,
block: BlockRV,
loop: LoopRV,
preserve_unit_loops: bool = False,
allow_block_predicate: bool = True,
) -> None:
"""Reverse-Compute-At. Move a consumer block under the specific loop, and regenerate the
loops induced by the block so that the buffer region consumed by the consumer block could
Expand All @@ -1129,6 +1135,9 @@ def reverse_compute_at(
preserve_unit_loops: bool
Whether to keep the trivial loops whose extents are 1
allow_block_predicate: bool
Whether to use block predicate for block iteration bounds
Examples
--------
Expand Down Expand Up @@ -1185,6 +1194,7 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
block,
loop,
preserve_unit_loops,
allow_block_predicate,
)

def compute_inline(self, block: BlockRV) -> None:
Expand Down
10 changes: 6 additions & 4 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
/******** Schedule: Compute location ********/

void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
bool preserve_unit_loops, bool allow_block_predicate) {
static StmtSRef inline_mark = StmtSRef::InlineMark();
static StmtSRef root_mark = StmtSRef::RootMark();
StmtSRef loop_sref = this->GetSRef(loop_rv);
Expand All @@ -492,14 +492,15 @@ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
} else {
TVM_TIR_SCHEDULE_BEGIN();
tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops);
tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops,
allow_block_predicate);
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
}
this->state_->DebugVerify();
}

void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
bool preserve_unit_loops, bool allow_block_predicate) {
static StmtSRef inline_mark = StmtSRef::InlineMark();
static StmtSRef root_mark = StmtSRef::RootMark();
StmtSRef loop_sref = this->GetSRef(loop_rv);
Expand All @@ -511,7 +512,8 @@ void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopR
TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_);
} else {
TVM_TIR_SCHEDULE_BEGIN();
tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops);
tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops,
allow_block_predicate);
TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_);
}
this->state_->DebugVerify();
Expand Down
7 changes: 4 additions & 3 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,10 @@ class ConcreteScheduleNode : public ScheduleNode {
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) override;
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
bool allow_block_predicate) override;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
bool allow_block_predicate) override;
void ComputeInline(const BlockRV& block) override;
void ReverseComputeInline(const BlockRV& block) override;
/******** Schedule: Reduction ********/
Expand Down
7 changes: 5 additions & 2 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,10 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int
* \param block_sref The block to be moved
* \param loop_sref The loop where the block to be moved to
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \param allow_block_predicate Whether to use block predicate for block iteration bounds
*/
TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops);
bool preserve_unit_loops, bool allow_block_predicate);
/*!
* \brief Move a consumer block under the specific loop, and regenerate the
* loops induced by the block so that the buffer region consumed by the consumer block could
Expand All @@ -260,9 +261,11 @@ TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stm
* \param block_sref The block to be moved
* \param loop_sref The loop where the block to be moved to
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \param allow_block_predicate Whether to use block predicate for iter bounds
*/
TVM_DLL void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops);
const StmtSRef& loop_sref, bool preserve_unit_loops,
bool allow_block_predicate);
/*!
* \brief Inline a block into its consumer(s). It requires:
* 1) The block is a complete non-root block, which only produces one buffer
Expand Down
58 changes: 39 additions & 19 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,11 @@ class ScopeReconstructor : private StmtMutator {
* \param iter_doms The domain of each block var
* \param analyzer The arithmetic analyzer
* \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1
* \param allow_block_predicate Whether to use block predicate for iter bounds
*/
void MakeNewLoop(int insert_position, std::vector<BlockVarDomainInfo> iter_doms,
arith::Analyzer* analyzer, bool preserve_unit_loops) {
void MakeNewLoop(int insert_position, const std::vector<BlockVarDomainInfo>& iter_doms,
arith::Analyzer* analyzer, bool preserve_unit_loops,
bool allow_block_predicate) {
int n_iters = iter_doms.size();
Array<Var> loop_vars;
Array<PrimExpr> loop_extents;
Expand All @@ -236,14 +238,24 @@ class ScopeReconstructor : private StmtMutator {
iter_values.reserve(n_iters);
PrimExpr predicate = const_true();
for (int i = 0; i < n_iters; ++i) {
Range iter_dom = iter_doms[i].dom.CoverRange(block_->iter_vars[i]->dom);
const arith::IntSet& pred_bound = iter_doms[i].bound;
arith::IntSet pred_bound = iter_doms[i].bound;
arith::IntSet iter_dom_intset = iter_doms[i].dom;
if (!pred_bound.IsNothing() && !allow_block_predicate) {
// if block predicate not prefered, use dom^bound for generated loop range
iter_dom_intset = arith::Intersect({iter_dom_intset, pred_bound});
pred_bound = arith::IntSet::Nothing();
}
Range iter_dom = iter_dom_intset.CoverRange(block_->iter_vars[i]->dom);
if (preserve_unit_loops || !is_one(iter_dom->extent) || !pred_bound.IsNothing()) {
Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(32));
loop_vars.push_back(var);
loop_extents.push_back(iter_dom->extent);
iter_values.push_back(iter_dom->min + var);
analyzer->Bind(var, Range::FromMinExtent(0, iter_dom->extent));
if (is_one(iter_dom->extent)) {
analyzer->Bind(var, 0);
} else {
analyzer->Bind(var, Range::FromMinExtent(0, iter_dom->extent));
}
if (pred_bound.HasLowerBound()) {
PrimExpr lower_bound = iter_dom->min + var >= pred_bound.min();
predicate = predicate && lower_bound;
Expand Down Expand Up @@ -542,7 +554,8 @@ void CalculateProvidedRequiredRegions(

template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops) {
const StmtSRef& loop_sref, bool preserve_unit_loops,
bool allow_block_predicate) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
// Step 1. Bunch of checks
Expand Down Expand Up @@ -595,7 +608,8 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
/*analyzer=*/&analyzer);
// Step 6. Create the new scope according to the iteration domain
reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms),
/*analyzer=*/&analyzer, /*preserve_unit_loops=*/preserve_unit_loops);
/*analyzer=*/&analyzer, /*preserve_unit_loops=*/preserve_unit_loops,
/*allow_block_predicate=*/allow_block_predicate);
Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
// Step 7. Do the actual replacement
self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
Expand All @@ -616,13 +630,15 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
}

void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops);
bool preserve_unit_loops, bool allow_block_predicate) {
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
allow_block_predicate);
}

void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops);
bool preserve_unit_loops, bool allow_block_predicate) {
ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
allow_block_predicate);
}

/******** InstructionKind Registration ********/
Expand All @@ -633,20 +649,22 @@ struct ComputeAtTraits : public UnpackedInstTraits<ComputeAtTraits> {

private:
static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
Bool preserve_unit_loops) {
return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
Bool preserve_unit_loops, Bool allow_block_predicate) {
return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(),
allow_block_predicate.operator bool());
}

static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
Bool preserve_unit_loops) {
Bool preserve_unit_loops, Bool allow_block_predicate) {
PythonAPICall py("compute_at");
py.Input("block", block_rv);
py.Input("loop", loop_rv);
py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
py.Input("allow_block_predicate", allow_block_predicate.operator bool());
return py.Str();
}

Expand All @@ -660,20 +678,22 @@ struct ReverseComputeAtTraits : public UnpackedInstTraits<ReverseComputeAtTraits

private:
static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
Bool preserve_unit_loops) {
return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
Bool preserve_unit_loops, Bool allow_block_predicate) {
return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(),
allow_block_predicate.operator bool());
}

static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
Bool preserve_unit_loops) {
Bool preserve_unit_loops, Bool allow_block_predicate) {
PythonAPICall py("reverse_compute_at");
py.Input("block", block_rv);
py.Input("loop", loop_rv);
py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
py.Input("allow_block_predicate", allow_block_predicate.operator bool());
return py.Str();
}

Expand Down
27 changes: 15 additions & 12 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,25 +253,28 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer
/******** Schedule: Compute location ********/

void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops);
bool preserve_unit_loops, bool allow_block_predicate) {
ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops, allow_block_predicate);

static const InstructionKind& kind = InstructionKind::Get("ComputeAt");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
/*attrs=*/{Integer(preserve_unit_loops)},
/*outputs=*/{}));
trace_->Append(
/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
/*attrs=*/{Integer(preserve_unit_loops), Integer(allow_block_predicate)},
/*outputs=*/{}));
}

void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops);
bool preserve_unit_loops, bool allow_block_predicate) {
ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops,
allow_block_predicate);

static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
/*attrs=*/{Integer(preserve_unit_loops)},
/*outputs=*/{}));
trace_->Append(
/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
/*attrs=*/{Integer(preserve_unit_loops), Integer(allow_block_predicate)},
/*outputs=*/{}));
}

void TracedScheduleNode::ComputeInline(const BlockRV& block_rv) {
Expand Down
7 changes: 4 additions & 3 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ class TracedScheduleNode : public ConcreteScheduleNode {
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) final;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) final;
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
bool allow_block_predicate) final;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
bool allow_block_predicate) final;
void ComputeInline(const BlockRV& block_rv) final;
void ReverseComputeInline(const BlockRV& block_rv) final;
/******** Schedule: Reduction ********/
Expand Down
Loading

0 comments on commit f91c1fc

Please sign in to comment.