[TIR] Enhance software pipeline validation and fix predicate of epilogue #11106

Merged: 4 commits, Apr 26, 2022
74 changes: 65 additions & 9 deletions src/tir/transforms/inject_software_pipeline.cc
@@ -534,7 +534,10 @@ class PipelineRewriter : public StmtExprMutator {
subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var);
} else {
// normalize loop range
subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + (start - pipeline_loop_->min));
PrimExpr delta = start - pipeline_loop_->min;
subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + delta);
Var loop_iter = Downcast<Var>(new_loop_var);
inbound = Substitute(inbound, Map<Var, PrimExpr>{{loop_iter, loop_iter + delta}});
}
new_block = Downcast<Block>(Substitute(new_block, subst_map));
stmts.push_back(BlockRealize({}, inbound, new_block));
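This extra substitution is the "fix predicate of epilogue" part of the PR title: the inbound predicate is built against the un-normalized loop range, so once the epilogue loop is normalized to start at zero, the loop variable inside the predicate has to be shifted by the same delta as the block body. A rough worked instance, read off the three_stage_compute test added below rather than stated anywhere in the PR:

delta = start - min                 # 16 for the epilogue of the 16-iteration, 3-stage pipeline
# the stage-1 block stays in range while x < 17, with x the skewed iteration variable
# substituting x -> i + 16 over the normalized epilogue loop gives i + 16 < 17, i.e. i < 1,
# which is the T.where(i < 1) guard visible in transformed_three_stage_compute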
@@ -570,6 +573,40 @@ class PipelineRewriter : public StmtExprMutator {
Array<Block> ordered_stmts_;
};

/*!
* \brief Build the dependency graph among an array of blocks.
* \param[in] blocks The array of blocks.
* \param[out] dep_src2dst Optional, a map to store dependency edges from the source to the
* destination.
* \param[out] dep_dst2src Optional, a map to store dependency edges from the
* destination to the source.
*/
void BuildDependencyGraph(
const Array<Block>& blocks,
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst,
std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) {
std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;

for (const Block& block : blocks) {
for (const BufferRegion& read : block->reads) {
auto it = buffer_writers.find(read->buffer->data);
if (it != buffer_writers.end()) {
for (const Block& writer : it->second) {
if (dep_src2dst != nullptr) {
(*dep_src2dst)[writer].push_back(block);
}
if (dep_dst2src != nullptr) {
(*dep_dst2src)[block].push_back(writer);
}
}
}
}
for (const BufferRegion& write : block->writes) {
buffer_writers[write->buffer->data].push_back(block);
}
}
}
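For intuition, the same read-after-write bookkeeping can be sketched in a few lines of Python (a simplified, hypothetical mirror of the helper above, with block reads and writes reduced to sets of buffer names):

from collections import defaultdict

def build_dependency_graph(blocks):
    """blocks: iterable of (name, reads, writes), with reads/writes as sets of buffer names."""
    dep_src2dst = defaultdict(list)      # producer block -> blocks that later read its output
    dep_dst2src = defaultdict(list)      # consumer block -> blocks whose output it reads
    buffer_writers = defaultdict(list)   # buffer name -> blocks that have written it so far
    for name, reads, writes in blocks:
        for buf in reads:
            for writer in buffer_writers[buf]:
                dep_src2dst[writer].append(name)   # read-after-write edge: writer -> name
                dep_dst2src[name].append(writer)
        for buf in writes:
            buffer_writers[buf].append(name)
    return dep_src2dst, dep_dst2src

# e.g. build_dependency_graph([("load", {"A"}, {"B"}), ("compute", {"B"}, {"C"})])
# gives dep_src2dst == {"load": ["compute"]}, the edge the validation below walks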

class PipelineInjector : private StmtExprMutator {
public:
static Stmt Inject(const PrimFunc& func) {
@@ -587,24 +624,43 @@

/*!
* \brief Check the pipeline satisfies the following conditions:
* 1) No conflicting order: The order of each statement should be unique.
* 2) No reordering with the same stage: Statements in the same stage are not allowed to be
* reordered.
* 1. No conflicting order: The order of each statement should be unique.
* 2. Reordering of statements doesn't break buffer access dependencies. Specifically, for a
* dependency (e.g. read-after-write) from statement A to statement B, one of the following
* must hold:
* case 1: stage(A) < stage(B)
* case 2: stage(A) == stage(B) and order(A) < order(B)
*/
void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array<Block>& original_order) {
std::unordered_set<int> used_orders;
std::unordered_map<int, int> stage_max_order;
std::unordered_map<int, const Block*> order_to_block;
std::unordered_map<const Block*, int> block_to_stage;
for (const Block& block : original_order) {
const auto& stmt_info = pipeline_info.at(block);
int stage = stmt_info.stage;
int order = stmt_info.order;
CHECK(!used_orders.count(order))
<< "ValueError: Two statements in the software pipeline cannot have the same order";
used_orders.insert(order);
CHECK(!stage_max_order.count(stage) || stage_max_order[stage] < order)
<< "ValueError: Statements in the same stage of the software pipeline must have "
"increasing order.";
stage_max_order[stage] = order;
}

std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual> dep_src2dst;
BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

for (const auto& pair : dep_src2dst) {
const Block& src = pair.first;
const auto& src_info = pipeline_info.at(src);
const Array<Block>& dsts = pair.second;
for (const Block& dst : dsts) {
const auto& dst_info = pipeline_info.at(dst);
CHECK_LE(src_info.stage, dst_info.stage)
<< "ValueError: statement " << dst << " in stage " << dst_info.stage
<< " cannot depends on statement " << src << " in a later stage " << src_info.stage;
if (src_info.stage == dst_info.stage) {
CHECK_LT(src_info.order, dst_info.order) << "ValueError: two statements with buffer "
"access dependency in the same stage of the "
"software pipeline cannot be reordered";
}
}
}
}
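To make the dependency condition concrete, a pipeline annotated along the following lines (a hypothetical input written here for illustration, modeled on the tests below and not part of the PR) would now be rejected, because the block consuming B is assigned to an earlier stage than the block producing it:

@T.prim_func
def invalid_stage_assignment(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
            0,
            16,
            annotations={
                "software_pipeline_stage": [1, 0],  # producer of B in stage 1, its consumer in stage 0
                "software_pipeline_order": [0, 1],
            },
        ):
            with T.block():
                T.reads(A[tx, i])
                T.writes(C[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, i])
                    C[tx, i] = B[tx, 0] + T.float32(1)

Running the pass over such a function should hit the CHECK_LE above and raise the "cannot depend on statement ... in a later stage" error; keeping both blocks in the same stage but swapping their order entries would instead hit the CHECK_LT branch.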

201 changes: 201 additions & 0 deletions tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -132,6 +132,199 @@ def transformed_simple_compute(
C[tx, 15] = B[1, tx, 0] + T.float32(1)


@T.prim_func
def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
for i in T.serial(
0,
16,
annotations={
"software_pipeline_stage": [0, 1, 2],
"software_pipeline_order": [0, 1, 2],
},
):
with T.block():
T.reads(A[tx, i])
T.writes(D[tx, i])
B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
C = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
with T.block():
T.reads(A[tx, i])
T.writes(B[tx, 0])
B[tx, 0] = A[tx, i] * T.float32(2)
with T.block():
T.reads(B[tx, 0])
T.writes(C[tx, 0])
C[tx, 0] = A[tx, 0] + T.float32(2)
with T.block():
T.reads(C[tx, 0])
T.writes(D[tx, i])
D[tx, i] = C[tx, 0] + T.float32(1)


@T.prim_func
def transformed_three_stage_compute(
A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]
) -> None:
for tx in T.thread_binding(16, thread="threadIdx.x"):
with T.block():
T.reads(A[tx, 0:16])
T.writes(D[tx, 0:16])
B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
with T.block():
T.reads(A[tx, 0:2], B[0:2, tx, 0])
T.writes(B[0:2, tx, 0], C[0:2, tx, 0])
for i in T.unroll(2):
with T.block():
T.reads(A[tx, i])
T.writes(B[0:2, tx, 0])
B[i, tx, 0] = A[tx, i] * T.float32(2)
with T.block():
T.where(1 <= i)
T.reads(B[0:2, tx, 0])
T.writes(C[0:2, tx, 0])
C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2)
with T.block():
T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
for i in T.serial(14):
with T.block():
T.reads(A[tx, i + 2])
T.writes(B[0:2, tx, 0])
B[i % 2, tx, 0] = A[tx, i + 2] * T.float32(2)
with T.block():
T.reads(B[0:2, tx, 0])
T.writes(C[0:2, tx, 0])
C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2)
with T.block():
T.reads(C[0:2, tx, 0])
T.writes(D[tx, i])
D[tx, i] = C[i % 2, tx, 0] + T.float32(1)
with T.block():
T.reads(B[0:2, tx, 0], C[0:2, tx, 0])
T.writes(C[0:2, tx, 0], D[tx, 14:16])
for i in T.unroll(2):
with T.block():
T.where(i < 1)
T.reads(B[0:2, tx, 0])
T.writes(C[0:2, tx, 0])
C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2)
with T.block():
T.reads(C[0:2, tx, 0])
T.writes(D[tx, i + 14])
D[tx, i + 14] = C[i, tx, 0] + T.float32(1)


@T.prim_func
def dag_interleaving(
A: T.Buffer[(16, 16), "float32"],
B: T.Buffer[(16, 16), "float32"],
C: T.Buffer[(16, 16), "float32"],
) -> None:
for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
for i in T.serial(
0,
16,
annotations={
"software_pipeline_stage": [0, 0, 0, 0, 1],
"software_pipeline_order": [0, 2, 1, 3, 4],
},
):
with T.block():
T.reads(A[tx, i])
T.writes(C[tx, i])
AS = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
BS = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
AL = T.alloc_buffer((1, 1), dtype="float32", scope="local")
BL = T.alloc_buffer((1, 1), dtype="float32", scope="local")
with T.block():
T.reads(A[tx, i])
T.writes(AS[tx, 0])
AS[tx, 0] = A[tx, i] * T.float32(2)
with T.block():
T.reads(AS[tx, 0])
T.writes(AL[0, 0])
AL[0, 0] = AS[tx, 0]
with T.block():
T.reads(B[tx, i])
T.writes(BS[tx, 0])
BS[tx, 0] = B[tx, i] + T.float32(2)
with T.block():
T.reads(BS[tx, 0])
T.writes(BL[0, 0])
BL[0, 0] = BS[tx, 0]
with T.block():
T.reads(AL[0, 0], BL[0, 0])
T.writes(C[tx, i])
C[tx, i] = AL[0, 0] * BL[0, 0]


@T.prim_func
def transformed_dag_interleaving(
A: T.Buffer[(16, 16), "float32"],
B: T.Buffer[(16, 16), "float32"],
C: T.Buffer[(16, 16), "float32"],
) -> None:
for tx in T.thread_binding(16, thread="threadIdx.x"):
with T.block():
T.reads(A[tx, 0:16], B[tx, 0:16])
T.writes(C[tx, 0:16])
AS = T.alloc_buffer([16, 1], dtype="float32", scope="shared")
BS = T.alloc_buffer([16, 1], dtype="float32", scope="shared")
AL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local")
BL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local")
with T.block():
T.reads(A[tx, 0], B[tx, 0], AS[tx, 0], BS[tx, 0])
T.writes(AS[tx, 0], BS[tx, 0], AL[0, 0, 0], BL[0, 0, 0])
with T.block():
T.reads(A[tx, 0])
T.writes(AS[tx, 0])
AS[tx, 0] = A[tx, 0] * T.float32(2)
with T.block():
T.reads(B[tx, 0])
T.writes(BS[tx, 0])
BS[tx, 0] = B[tx, 0] + T.float32(2)
with T.block():
T.reads(AS[tx, 0])
T.writes(AL[0, 0, 0])
AL[0, 0, 0] = AS[tx, 0]
with T.block():
T.reads(BS[tx, 0])
T.writes(BL[0, 0, 0])
BL[0, 0, 0] = BS[tx, 0]
with T.block():
T.reads(
A[tx, 1:16], B[tx, 1:16], AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0]
)
T.writes(AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0], C[tx, 0:15])
for i in T.serial(15):
with T.block():
T.reads(A[tx, i + 1])
T.writes(AS[tx, 0])
AS[tx, 0] = A[tx, i + 1] * T.float32(2)
with T.block():
T.reads(B[tx, i + 1])
T.writes(BS[tx, 0])
BS[tx, 0] = B[tx, i + 1] + T.float32(2)
with T.block():
T.reads(AS[tx, 0])
T.writes(AL[(i + 1) % 2, 0, 0])
AL[(i + 1) % 2, 0, 0] = AS[tx, 0]
with T.block():
T.reads(BS[tx, 0])
T.writes(BL[(i + 1) % 2, 0, 0])
BL[(i + 1) % 2, 0, 0] = BS[tx, 0]
with T.block():
T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0])
T.writes(C[tx, i])
C[tx, i] = AL[i % 2, 0, 0] * BL[i % 2, 0, 0]
with T.block():
T.reads(AL[1, 0, 0], BL[1, 0, 0])
T.writes(C[tx, 15])
C[tx, 15] = AL[1, 0, 0] * BL[1, 0, 0]


@T.prim_func
def nested_pipeline_simple(
A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
@@ -792,6 +985,14 @@ def test_trivial_pipeline():
_check(trivial_pipeline, transformed_trivial_pipeline)


def test_three_stage_compute():
_check(three_stage_compute, transformed_three_stage_compute)


def test_dag_interleaving():
_check(dag_interleaving, transformed_dag_interleaving)


def test_nest_pipeline_simple():
_check(nested_pipeline_simple, transformed_nested_pipeline_simple)
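The _check helper used by these tests is defined earlier in the file and not shown in this diff; presumably it applies the pass and compares the result against the hand-written expectation, roughly along the lines of this sketch (an assumption, not the exact helper):

import tvm

def _check(original, transformed):
    # Wrap the input PrimFunc, inject the software pipeline, simplify, and
    # compare structurally against the expected transformed function.
    mod = tvm.IRModule.from_expr(original)
    mod = tvm.tir.transform.InjectSoftwarePipeline()(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    tvm.ir.assert_structural_equal(mod["main"], transformed, True)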
