Skip to content

Commit

Permalink
resolve unstaged changes in inject_software_pipeline.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Mar 25, 2023
1 parent dcc5417 commit 35a2d3f
Showing 1 changed file with 11 additions and 22 deletions.
33 changes: 11 additions & 22 deletions src/tir/transforms/inject_software_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,9 @@ class PipelineRewriter : public StmtExprMutator {
const Array<Buffer> pipeline_allocs, const For& pipeline_loop,
const PipelineInfo& pipeline_info,
const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info,
const Map<String, ObjectRef> preserved_annotations, bool merge_async_commit_queue_scope) {
const Map<String, ObjectRef> preserved_annotations) {
PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop,
pipeline_info, fragment_info, preserved_annotations,
merge_async_commit_queue_scope);
pipeline_info, fragment_info, preserved_annotations);
return rewriter.BuildPipeline();
}

Expand All @@ -322,17 +321,15 @@ class PipelineRewriter : public StmtExprMutator {
const Array<Buffer>& pipeline_allocs, const For& pipeline_loop,
const PipelineInfo& pipeline_info,
const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info,
const Map<String, ObjectRef> preserved_annotations,
bool merge_async_commit_queue_scope)
const Map<String, ObjectRef> preserved_annotations)

: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
double_buffers_(double_buffers),
pipeline_allocs_(pipeline_allocs),
pipeline_loop_(pipeline_loop),
pipeline_info_(pipeline_info),
fragment_info_(fragment_info),
preserved_annotations_(preserved_annotations),
merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {}
preserved_annotations_(preserved_annotations) {}

Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions
Expand Down Expand Up @@ -766,7 +763,7 @@ class PipelineRewriter : public StmtExprMutator {
group_bodies.push_back(new_blocks[i].block->body);
}

if (merge_async_commit_queue_scope_ && group_bodies.size() > 1) {
if (group_bodies.size() > 1) {
auto merged_bodies = SeqStmt(group_bodies);
group_bodies.clear();
group_bodies.push_back(merged_bodies);
Expand Down Expand Up @@ -853,8 +850,7 @@ class PipelineRewriter : public StmtExprMutator {
auto& local_state = async_states_local[stage];

int commit_group_id = -1;
if (local_state.commit_groups.empty() || local_state.consumed ||
!merge_async_commit_queue_scope_) {
if (local_state.commit_groups.empty() || local_state.consumed) {
// consumed == true means there is already a consumer stage waiting for an
// eariler async operation of this stage. In such cases, we make multiple commit_queue
// for this stage.
Expand Down Expand Up @@ -954,7 +950,6 @@ class PipelineRewriter : public StmtExprMutator {
Array<Block> ordered_stmts_;
std::map<int, AsyncStateGlobal> async_states;
Map<String, ObjectRef> preserved_annotations_;
bool merge_async_commit_queue_scope_ = true;
};

/*!
Expand Down Expand Up @@ -1146,9 +1141,9 @@ class PipelineInjector : private StmtExprMutator {
ValidatePipelineBody(pipeline_info, original_order);

// Step 4: Rewrite the pipeline body.
Stmt pipeline = PipelineRewriter::Rewrite(
buffer_data_to_buffer_, double_buffers, pipeline_allocs, GetRef<For>(op), pipeline_info,
fragment_info_, preserved_annotations, merge_async_commit_queue_scope_);
Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers,
pipeline_allocs, GetRef<For>(op), pipeline_info,
fragment_info_, preserved_annotations);

if (const auto* realize = op->body.as<BlockRealizeNode>()) {
const auto& block = realize->block;
Expand Down Expand Up @@ -1217,11 +1212,7 @@ class PipelineInjector : private StmtExprMutator {
Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_map<const VarNode*, FragmentInfo> fragment_info_;
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> double_buffers;
<<<<<<< HEAD
bool merge_async_commit_queue_scope_ = true;
=======
Optional<String> global_symbol_;
>>>>>>> 0d0d2f0bd33667316a255212e89a408a5f541817
};

} // namespace software_pipeline
Expand All @@ -1235,9 +1226,7 @@ namespace transform {
Pass InjectSoftwarePipeline() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* fptr = f.CopyOnWrite();
bool merge_async_commit_queue_scope =
ctx->GetConfig<Bool>("tir.merge_async_commit_queue_scope", Bool(true)).value();
fptr->body = software_pipeline::PipelineInjector::Inject(f, merge_async_commit_queue_scope);
fptr->body = software_pipeline::PipelineInjector::Inject(f);
fptr->body = ConvertSSA(std::move(fptr->body));
return f;
};
Expand All @@ -1249,4 +1238,4 @@ TVM_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline").set_body_typed(Injec
} // namespace transform

} // namespace tir
} // namespace tvm
} // namespace tvm

0 comments on commit 35a2d3f

Please sign in to comment.