Skip to content

Commit

Permalink
Conditional Loop Partitioning - Extending to remove if conditions (ap…
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and tqchen committed Oct 30, 2018
1 parent feca27e commit ea74668
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 14 deletions.
47 changes: 33 additions & 14 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,16 @@ class ThreadPartitionInserter : public IRMutator {
// Try to do partition at the candidate IRs
class LoopPartitioner : public IRMutator {
public:
explicit LoopPartitioner(std::unordered_set<const Node*> candidates)
: candidates_(candidates) {}
explicit LoopPartitioner(bool split_const_loop)
: selector(CandidateSelector(split_const_loop)) {}

Stmt VisitAndMutate(const Stmt& stmt) {
selector.Visit(stmt);
return Mutate(stmt);
}

Stmt Mutate_(const For* op, const Stmt& stmt) {
if (candidates_.count(op)) {
if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, stmt, op->loop_var,
op->min, op->min + op->extent - 1, op->body, false);
if (s.defined()) return s;
Expand All @@ -266,7 +271,7 @@ class LoopPartitioner : public IRMutator {
const IterVarNode *iv = op->node.as<IterVarNode>();
CHECK(iv);
Var var = iv->var;
if (candidates_.count(op)) {
if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true);
if (s.defined()) return s;
}
Expand Down Expand Up @@ -295,9 +300,9 @@ class LoopPartitioner : public IRMutator {
inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);

/* Candidate IRs that may be partitioned potentially */
std::unordered_set<const Node*> candidates_;
std::unordered_map<const Variable*, IntSet> hint_map_;
std::unordered_map<const Variable*, IntSet> relax_map_;
CandidateSelector selector;
};

Stmt LoopPartitioner::TryPartition(const Node* node,
Expand All @@ -322,7 +327,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr body_begin;
Stmt pre_stmt;
if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
body_begin = true_itrv.min();
body_begin = ir::Simplify(true_itrv.min());
if (!can_prove(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
if (!can_prove(cond)) {
Expand All @@ -343,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr post_doubt_begin;
Stmt post_stmt;
if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
post_doubt_begin = true_itrv.max() + 1;
post_doubt_begin = ir::Simplify(true_itrv.max() + 1);
if (!can_prove(true_itrv.max() == max)) {
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
Expand All @@ -354,8 +359,17 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
}
// [post_doubt_begin, max]
if (!partition_thread_scope) {
Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
Stmt post_body;
// If the loop is going from 0 to 1, replace the loop var with min value
if (as_const_int(max) && as_const_int(post_doubt_begin)) {
if (*as_const_int(max) == *as_const_int(post_doubt_begin)) {
post_body = Substitute(body, {{Var{var}, post_doubt_begin}});
post_stmt = post_body;
}
} else {
post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
}
}
} else {
Expand All @@ -368,8 +382,15 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Stmt simplified_body = ConditionEliminator(partitions).Mutate(body);
Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
s = MakeFor(node, post_doubt_begin - body_begin, new_body);
if (pre_stmt.defined()) s = Block::make(pre_stmt, s);
if (post_stmt.defined()) s = Block::make(s, post_stmt);

if (!(pre_stmt.defined() && post_stmt.defined())) s = VisitAndMutate(s);
if (pre_stmt.defined()) s = Block::make(pre_stmt, s);
if (post_stmt.defined()) {
if (as_const_int(max) && as_const_int(post_doubt_begin)) {
post_stmt = VisitAndMutate(post_stmt);
}
s = Block::make(s, post_stmt);
}
} else {
Expr cond = const_true();
if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin);
Expand Down Expand Up @@ -402,9 +423,7 @@ class RemoveLikelyTags : public IRMutator {
};

Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
CandidateSelector selector(split_const_loop);
selector.Visit(stmt);
stmt = LoopPartitioner(selector.candidates).Mutate(stmt);
stmt = LoopPartitioner(split_const_loop).VisitAndMutate(stmt);
stmt = RemoveLikelyTags().Mutate(stmt);
return stmt;
}
Expand Down
158 changes: 158 additions & 0 deletions tests/python/unittest/test_pass_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,157 @@ def test_everything_during_deduction():
stmt = tvm.ir_pass.Simplify(stmt)
assert(isinstance(stmt.body.body, tvm.stmt.IfThenElse))

def test_single_likely():
n = 60
A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((n, ), name='B')

T = tvm.compute((n, ), lambda i: A[i]+B[i])
s = tvm.create_schedule(T.op)
x = T.op.axis[0]
xo, xi = s[T].split(x, factor=16)

bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))

def test_multi_likely():
n = 94
m = 62
A = tvm.placeholder((n, m), name='A')
B = tvm.placeholder((n, m), name='B')

T = tvm.compute((n, m), lambda i, j: A[i, j]+B[i, j])
s = tvm.create_schedule(T.op)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
x, y = T.op.axis
xo, xi = s[T].split(x, factor=16)
yo, yi = s[T].split(y, factor=16)
s[T].reorder(xo, yo, xi, yi)

bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))

def test_oneD_pool():
m = tvm.var('m')
ib = tvm.ir_builder.create()
#data = tvm.placeholder((16,), name = 'data')
data = ib.pointer("float32", name="A")
out = ib.pointer("float32", name="A")
with ib.for_range(0, 16, 'ow') as ow:
with ib.for_range(0, 3, 'kw') as kw:
with ib.if_scope(ib.likely(ow > 0)):
with ib.if_scope(ib.likely(ow < 15)):
out[ow] = tvm.max(out[ow], data[ow + kw - 1])
with ib.for_range(0, 16, 'ow') as ow:
with ib.for_range(0, 3, 'kw') as kw:
with ib.if_scope(ib.likely(ow < 1)):
with ib.if_scope(ib.likely(kw > 0)):
out[ow] = tvm.max(out[ow], data[ow + kw - 1])
with ib.for_range(0, 16, 'ow') as ow:
with ib.for_range(0, 3, 'kw') as kw:
with ib.if_scope(ib.likely(ow > 14)):
with ib.if_scope(ib.likely(kw < 2)):
out[ow] = tvm.max(out[ow], data[ow + kw - 1])

stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))

def test_cce_loop_1():
ib = tvm.ir_builder.create()
dtype = 'float16'
n = 514
m = 514
_A = tvm.placeholder((n*m,), name = 'A')
Ab = tvm.decl_buffer((n*m,), dtype, name="A")
A = ib.buffer_ptr(Ab)
_B = tvm.placeholder((n*m,), name = 'B')
Bb = tvm.decl_buffer((n*m,), dtype, name="B")
B = ib.buffer_ptr(Bb)
#for i in 0 to n-1:
with ib.for_range(0, 11, name="i") as i:
with ib.for_range(0, 160, name="j") as j:
with ib.if_scope(ib.likely(((i*160) + j) < 1600)):
A[(i+1)*m+j+1] = B[(i)*m+j+1] + B[(i+1)*m+j+1] + B[(i+2)*m+j+1]
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))

def test_cce_loop_2():
ib = tvm.ir_builder.create()
len = 112
tile = 32
loop = (len + tile - 1) // tile
with ib.for_range(0, loop, 'i') as i:
head = i * tile
with ib.if_scope(ib.likely(head + tile > len)):
tail = len
ib.emit(tvm.call_extern('float32', "cce_intrisic", head, tail))
with ib.else_scope():
tail = head + tile
ib.emit(tvm.call_extern('float32', "cce_intrisic", head, tail))

stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))


def test_cce_loop_3():
ib = tvm.ir_builder.create()
loop1 = 4
loop2 = 9998
tile = 39991
with ib.for_range(0,loop2,'i') as i:
with ib.for_range(0,loop1,'j') as j:
head1 = i
head2 = j
with ib.if_scope(ib.likely(head1*loop1 + head2 < tile)):
ib.emit(tvm.call_extern('float16',"cce_intrisic",head1))

stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt,True)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))

def test_conv_tiling():
HSTR = WSTR = 1
in_channel = 128
kernel_height = kernel_width = 3
out_channel = 64
batch_size = 1
in_height = in_width = 64
out_height = out_width = in_height - kernel_height + 1
data = tvm.placeholder((batch_size, in_channel, in_height, in_width), name='data')
kernel = tvm.placeholder((kernel_height, kernel_width, in_channel,
out_channel), name='kernel')
ic = tvm.reduce_axis((0, in_channel), name='ic')
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')
conv = tvm.compute((batch_size, out_channel, out_height, out_width),
lambda n, oc, oh, ow: tvm.sum(data[n, ic, oh*HSTR + kh, ow*WSTR + kw] *
kernel[kh, kw, ic, oc],
axis=[ic, kh, kw]),
name="conv2d")
s = tvm.create_schedule(conv.op)

n, oc, oh, ow = conv.op.axis
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))

if __name__ == "__main__":
test_basic()
test_const_loop()
Expand All @@ -187,3 +338,10 @@ def test_everything_during_deduction():
test_select()
test_thread_axis2()
test_everything_during_deduction()
test_single_likely()
test_multi_likely()
test_oneD_pool()
test_cce_loop_1()
test_cce_loop_2()
test_cce_loop_3()
test_conv_tiling()

0 comments on commit ea74668

Please sign in to comment.