Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN/Fusion] horizontal support dynamic shape and enhance fusion ability #63913

Merged
merged 26 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ void ApplyDivideGroupOpToFusionOpPass(
std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
if (FLAGS_group_schedule_tiling_first) {
pass_manager->AddPass(cinn::dialect::ir::CreateCinnGroupClusterPass());
pass_manager->AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass());
// pass_manager->AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete this
可以单独提一个pr处理

} else {
pass_manager->AddPass(
cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ std::vector<GroupClusterNode> GroupSplit(cinn::dialect::GroupOp group_op) {
std::function<cinn::fusion::FrontendContent(pir::Operation*)> func =
[](pir::Operation* op) { return cinn::fusion::FrontendContent(op); };
const auto& contents = cinn::fusion::MapVector(group_op.GetOperators(), func);
auto cluster_result = cinn::fusion::ClusterOps(contents);
auto cluster_result = cinn::fusion::ClusterOps(contents, {});
std::vector<std::vector<pir::Operation*>> result;
std::transform(
cluster_result.begin(),
Expand Down Expand Up @@ -390,6 +390,9 @@ class CinnGroupClusterPass : public pir::PatternRewritePass {
}

bool CanApplyOn(pir::Operation* op) const override {
if (op->isa<cinn::dialect::FusionOp>()) {
return false;
}
return op->num_regions() > 0;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,16 @@ void SetLeafBlockByGroupView(
value_dim_exprs_list,
value_to_dim_expr_idx);

// We should update the GlobalShapeAnalysisCache after the group is cloned.
for (const auto& op : new_group->ops()) {
for (const auto& v : op->results()) {
auto* shape_analysis =
&pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
shape_analysis->SetShapeOrDataForValue(v,
new_group->GetShapeOrDataExprs(v));
2742195759 marked this conversation as resolved.
Show resolved Hide resolved
}
}

// Insert YieldOp for outputs
std::vector<pir::Value> outputs;
builder.SetInsertionPointToBlockEnd(block);
Expand Down
5 changes: 3 additions & 2 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(

// =========== OpFusion ============

func_bodies = OperationFusion(ops, func_bodies);
// VLOG(4) << "Bucket Lower output values is : " << group->output_values();
func_bodies = OperationFusion(ops, func_bodies, group->output_values());
2742195759 marked this conversation as resolved.
Show resolved Hide resolved
const auto& fusion_group_info = GetFusionGroupInfo(func_bodies);

// =========== CodeGen And Optimizer ================
Expand Down Expand Up @@ -728,7 +729,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
group->mut_output_names().clear();

// collect all output tensor.
for (auto op_result : group->GetGroupOutputValues()) {
for (auto op_result : group->output_values()) {
if (tensor_map.count(op_result) == 0) {
continue;
}
Expand Down
66 changes: 47 additions & 19 deletions paddle/cinn/hlir/framework/pir/trivial_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,17 +187,21 @@ std::vector<ir::Var> GetOutputIters(const FusibleOp& op) {
return AppendBound(std::visit(Visitor(), op), _GetRootExpr(op));
}

std::vector<ir::Var> GetAllIterVars(const ir::Expr& expr) {
ir::Expr compute_schedule_block_realize =
(ExprSetFinderUtils::ChildScheduleBlockRealizes *
ExprSetFinderUtils::ScheduleBlockRealizeIsNotInit)
.GetSingle(expr);

const std::vector<Expr>& all_iter_expr =
compute_schedule_block_realize.As<ir::ScheduleBlockRealize>()
->iter_values;
return ComposeUtils::ExprVec2VarVec(all_iter_expr);
}

std::vector<ir::Var> GetReduceIters(const ReduceOp& op) {
auto GetUnorderedAllIterVars = [](const ReduceOp& op) {
ir::Expr compute_schedule_block_realize =
(ExprSetFinderUtils::ChildScheduleBlockRealizes *
ExprSetFinderUtils::ScheduleBlockRealizeIsNotInit)
.GetSingle(_GetRootExpr(op));

const std::vector<Expr>& all_iter_expr =
compute_schedule_block_realize.As<ir::ScheduleBlockRealize>()
->iter_values;
return ComposeUtils::ExprVec2VarVec(all_iter_expr);
return GetAllIterVars(_GetRootExpr(op));
};

// Iter Vars not appearing in outer_iter_vars are pushed into
Expand Down Expand Up @@ -560,16 +564,38 @@ std::pair<TrivialOp, ReduceOp> SplitReduceOp(const ReduceOp& reduce_op) {
return std::make_pair(result_trivial, result_reduce);
}

std::vector<ir::Var> GetAllForIters(const ir::Expr& expr) {
using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils::
ChildFors;
using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils::
ChildScheduleBlockRealizes;
using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils::
FindFather;
using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils::
IsFor;
using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils::
ScheduleBlockRealizeIsNotInit;
const auto& all_father_fors =
(ChildScheduleBlockRealizes * ScheduleBlockRealizeIsNotInit *
FindFather(expr) * IsFor)(expr);
std::vector<ir::Var> vars;
for (const auto& for_expr : all_father_fors) {
vars.push_back(for_expr.As<ir::For>()->loop_var);
}
VLOG(4) << "GetAllForIters : " << expr
<< "\n var is : " << utils::Join(vars, ",");
return vars;
}

} // namespace trivial_fusion_detail

std::vector<ir::Expr> OperationFusion(
const std::vector<::pir::Operation*>& original_ops,
const std::vector<ir::Expr>& op_compute_bodies) {
PADDLE_ENFORCE_EQ(FLAGS_group_schedule_tiling_first,
true,
::common::errors::PreconditionNotMet(
"TrivialFusion must be used with tiling first, set "
"FLAGS_group_schedule_tiling_first=1"));
const std::vector<ir::Expr>& op_compute_bodies,
const std::vector<::pir::Value>& outputs) {
CHECK(FLAGS_group_schedule_tiling_first)
2742195759 marked this conversation as resolved.
Show resolved Hide resolved
<< "TrivialFusion must be used with tiling first, set "
"FLAGS_group_schedule_tiling_first=1";
const auto& ops = trivial_fusion_detail::FilterVector(
original_ops, [](const ::pir::Operation* op) {
if (op->name() == "cinn_op.generate_shape") {
Expand All @@ -581,10 +607,9 @@ std::vector<ir::Expr> OperationFusion(
std::vector<cinn::fusion::BackendContent> contents;
for (int i = 0; i < ops.size(); i++) {
contents.emplace_back(ops[i], op_compute_bodies[i]);
// contents.emplace_back(ops[i]);
}
const auto& fusion_nodes =
cinn::fusion::ClusterOps<cinn::fusion::BackendStage>(contents);
cinn::fusion::ClusterOps<cinn::fusion::BackendStage>(contents, outputs);

PADDLE_ENFORCE_EQ(fusion_nodes.size(),
1,
Expand All @@ -601,6 +626,8 @@ std::vector<ir::Expr> OperationFusion(

FusionGroupInfo GetFusionGroupInfo(
const std::vector<ir::Expr>& op_compute_bodies) {
using trivial_fusion_detail::AppendBound;
using trivial_fusion_detail::GetAllForIters;
using trivial_fusion_detail::ReduceOp;
using trivial_fusion_detail::ComposeUtils::ConcatVector;
using trivial_fusion_detail::ExprSetFinderUtils::ChildScheduleBlockRealizes;
Expand All @@ -618,7 +645,7 @@ FusionGroupInfo GetFusionGroupInfo(
ReduceOp op = ReduceOp(body);
if (group_info.reduce_var_name.empty()) {
std::vector<ir::Var> all_iters =
ConcatVector(GetOutputIters(op), GetReduceIters(op));
AppendBound(GetAllForIters(body), body);
std::transform(all_iters.begin(),
all_iters.end(),
std::back_inserter(group_info.loop_ranges),
Expand All @@ -631,7 +658,8 @@ FusionGroupInfo GetFusionGroupInfo(
return (int64_t)-1;
}
});
std::vector<ir::Var> reduce_iters = GetReduceIters(op);
std::vector<ir::Var> reduce_iters = fusion::FilterVector(
all_iters, [](const ir::Var& var) { return var->is_reduce_axis; });
for (int64_t i = all_iters.size() - reduce_iters.size();
i < all_iters.size();
i++) {
Expand Down
7 changes: 6 additions & 1 deletion paddle/cinn/hlir/framework/pir/trivial_op_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ FusibleOp SinkTrivialLoopAlign(TrivialOp trivial_op,
ReduceOp reduce_op,
std::vector<size_t> fake_reduce_iter_idx);

std::vector<ir::Var> GetAllIterVars(const ir::Expr& expr);

std::vector<ir::Var> GetAllForIters(const ir::Expr& expr);

} // namespace trivial_fusion_detail

struct FusionGroupInfo {
Expand All @@ -178,7 +182,8 @@ FusionGroupInfo GetFusionGroupInfo(

std::vector<ir::Expr> OperationFusion(
const std::vector<::pir::Operation*>& ops,
const std::vector<ir::Expr>& op_compute_bodies);
const std::vector<ir::Expr>& op_compute_bodies,
const std::vector<::pir::Value>& outputs);

} // namespace pir
} // namespace framework
Expand Down
76 changes: 70 additions & 6 deletions paddle/cinn/hlir/framework/pir/trivial_op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ std::vector<ir::Expr> VarVec2ExprVec(const std::vector<ir::Var>& in) {
std::vector<ir::Expr> GetEachTensorLoadExpr(const ir::Expr& body,
const ir::Tensor& tensor) {
VLOG(4) << "GetEachTensorLoadExpr: " << tensor;
std::set<Expr> load_exprs = cinn::ir::ir_utils::CollectIRNodesWithoutTensor(
std::vector<Expr> load_exprs = cinn::ir::ir_utils::CollectIRNodesInOrder(
body, [&tensor](const Expr* expr) {
return expr->As<ir::Load>() && expr->As<ir::Load>()->is_addr_tensor() &&
expr->As<ir::Load>()->tensor.as_tensor_ref()->name ==
Expand All @@ -65,7 +65,7 @@ std::vector<ir::Expr> GetEachTensorLoadExpr(const ir::Expr& body,
for (auto& t : load_exprs) {
VLOG(4) << "GetEachTensorLoadExpr Found: " << t << " " << t.ptr();
}
return std::vector(load_exprs.begin(), load_exprs.end());
return load_exprs;
}

MappingTargetExprToDestExprMutator::MappingTargetExprToDestExprMutator(
Expand All @@ -83,6 +83,16 @@ void MappingTargetExprToDestExprMutator::Visit(const ir::Load* load, Expr* op) {
IRMutator::Visit(load, op);
}
}

void MappingTargetExprToDestExprMutator::Visit(const ir::For* for_node,
Expr* op) {
if (for_node == source_.ptr()) {
*op = dest_;
} else {
IRMutator::Visit(for_node, op);
}
}

void MappingTargetExprToDestExprMutator::Visit(const ir::Store* store,
Expr* op) {
if (store == source_.ptr()) {
Expand All @@ -91,6 +101,7 @@ void MappingTargetExprToDestExprMutator::Visit(const ir::Store* store,
IRMutator::Visit(store, op);
}
}

void MappingTargetExprToDestExprMutator::Visit(const ir::Reduce* reduce,
Expr* op) {
if (reduce == source_.ptr()) {
Expand Down Expand Up @@ -196,7 +207,7 @@ ExprSetFinder ExprSetFinder::operator*(ExprSetFinder x) const {
std::vector<ir::Expr> res;
for (const auto& r : rs) {
const auto& x_res = x.f_(r);
res.insert(res.begin(), x_res.begin(), x_res.end());
res.insert(res.end(), x_res.begin(), x_res.end());
}
return res;
};
Expand Down Expand Up @@ -246,6 +257,15 @@ ExprSetFinder ScheduleBlockRealizeNotRoot = FilterMaker(
},
"ScheduleBlockRealizeNotRoot");

ExprSetFinder ScheduleBlockRealizeIsRoot = FilterMaker(
[](const ir::Expr& e) -> bool {
return (e.As<ir::ScheduleBlockRealize>() &&
e.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name.find("root") != std::string::npos);
},
"ScheduleBlockRealizeIsRoot");

ExprSetFinder ScheduleBlockRealizeIsNotInit = FilterMaker(
[](const ir::Expr& e) -> bool {
return (e.As<ir::ScheduleBlockRealize>() &&
Expand Down Expand Up @@ -277,6 +297,12 @@ ExprSetFinder ChildScheduleBlockRealizes =
"ChildScheduleBlockRealizes") *
ScheduleBlockRealizeNotRoot;

ExprSetFinder ChildRootScheduleBlockRealizes =
Collector(
[](const ir::Expr* e) { return e->As<ir::ScheduleBlockRealize>(); },
"ChildScheduleBlockRealizes") *
ScheduleBlockRealizeIsRoot;

ExprSetFinder IsForIterVar(const ir::Var& var) {
return FilterMaker(
[var = var](const ir::Expr& e) -> bool {
Expand Down Expand Up @@ -304,7 +330,7 @@ ExprSetFinder ChildTensorLoads = Collector(

ExprSetFinder ChildTensorStores = Collector(
[](const ir::Expr* e) {
return e->As<ir::Load>() && e->As<ir::Store>()->is_addr_tensor();
return e->As<ir::Store>() && e->As<ir::Store>()->is_addr_tensor();
},
"ChildTensorStores");

Expand All @@ -324,8 +350,10 @@ ExprSetFinder FindFather(const ir::Expr& root) {
const auto& f = [&](const auto& child) -> ExprSet {
ExprSetFinder find_child =
Collector([child](const ir::Expr* e) { return *e == child; });
const auto& father_collector = Collector(
[&](const ir::Expr* current) { return !find_child(*current).empty(); });
const auto& father_collector = Collector([&](const ir::Expr* current) {
auto res = (*current != child) && !find_child(*current).empty();
return res;
});
return father_collector(root);
};
return ExprSetFinder(f, "FindFather");
Expand Down Expand Up @@ -373,6 +401,34 @@ ExprTransformer WrapForsTransformer(const std::vector<ir::Var>& vs) {
return ExprTransformer(f);
}

ExprTransformer UnsqueezeForTransformer(
const ExprSetFinderUtils::ExprSetFinder& followed_finder,
const ir::Var& to_append_var) {
const auto& f = [&](const ir::Expr& e) -> ir::Expr {
2742195759 marked this conversation as resolved.
Show resolved Hide resolved
auto copied_e = ir::ir_utils::IRCopy(e);
ir::Expr followed_expr = followed_finder.GetSingle(copied_e);
// (ExprSetFinderUtils::ChildFors *
// ExprSetFinderUtils::IsForIterVar(following_for_iter_var)).GetSingle(copied_e);
VLOG(6) << "UnsqueezeForTransformer: for insert after " << followed_expr;
if (followed_expr.As<ir::For>()) {
followed_expr.As<ir::For>()->body = ir::Block::Make({WrapForTransformer(
to_append_var)(followed_expr.As<ir::For>()->body)});
} else if (followed_expr.As<ir::ScheduleBlockRealize>()) {
const auto& schedule_block = followed_expr.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>();
schedule_block->body =
WrapForTransformer(to_append_var)(schedule_block->body);
} else {
VLOG(6) << "UnsqueezeForTransformer: only support insert after a (For / "
2742195759 marked this conversation as resolved.
Show resolved Hide resolved
"ScheduleBlockRealizer): "
<< followed_expr;
}
VLOG(6) << "UnsqueezeForTransformer: After changed: " << copied_e;
return copied_e;
};
return ExprTransformer(f);
}

ExprTransformer ChangeTensorLoadTransformer(const ir::Tensor& tensor,
const ir::Expr& dst_load) {
const auto& f = [&](const ir::Expr& e) -> ir::Expr {
Expand Down Expand Up @@ -420,6 +476,14 @@ ExprTransformer ChangeVarTransformer(const std::vector<ir::Var>& target_vars,
return ExprTransformer(f);
}

ExprTransformer ReplaceVarTransformer(const std::vector<ir::Var>& target_vars,
const std::vector<ir::Expr>& dest_expr) {
const auto& f = [=](const ir::Expr& e) -> ir::Expr {
return ComposeUtils::CopyedReplaceExpr(e, target_vars, dest_expr);
};
return ExprTransformer(f);
}

bool IsReduceBool(const ir::Expr& lhs, const ir::Expr& rhs) {
return lhs.type().is_bool() || rhs.type().is_bool();
}
Expand Down
13 changes: 12 additions & 1 deletion paddle/cinn/hlir/framework/pir/trivial_op_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ struct MappingTargetExprToDestExprMutator : public ir::IRMutator<> {
void Visit(const ir::Load* load, Expr* op) override;
void Visit(const ir::Store* store, Expr* op) override;
void Visit(const ir::Reduce* reduce, Expr* op) override;
void Visit(const ir::For* for_node, Expr* op) override;

private:
ir::Expr source_;
Expand Down Expand Up @@ -132,7 +133,7 @@ template <typename Teller>
ExprSetFinder Collector(Teller t, std::string name = "") {
return ExprSetFinder(
[=](const ir::Expr& x) -> ExprSet {
const auto& rs = cinn::ir::ir_utils::CollectIRNodesWithoutTensor(x, t);
const auto& rs = cinn::ir::ir_utils::CollectIRNodesInOrder(x, t);
return std::vector(rs.begin(), rs.end());
},
name);
Expand Down Expand Up @@ -170,6 +171,8 @@ extern ExprSetFinder ChildScheduleBlocks;

extern ExprSetFinder ChildScheduleBlockRealizes;

extern ExprSetFinder ChildRootScheduleBlockRealizes;

extern ExprSetFinder For2Min;

extern ExprSetFinder For2Max;
Expand Down Expand Up @@ -230,6 +233,14 @@ std::vector<ir::Var> CreateInnerBlockVars(
ExprTransformer ChangeVarTransformer(const std::vector<ir::Var>& target_vars,
const std::vector<ir::Var>& dest_vars);

ExprTransformer ReplaceVarTransformer(const std::vector<ir::Var>& target_vars,
const std::vector<ir::Expr>& dest_exprs);

// insert after followed_finder. only support For and ScheduleBlockRealizer
ExprTransformer UnsqueezeForTransformer(
const ExprSetFinderUtils::ExprSetFinder& followed_finder,
const ir::Var& to_append_var);

ExprTransformer SubstitudeByScheduleBlockRealize(const ir::Expr& realize);

ExprTransformer WrapScheduleRealizer(const std::vector<ir::Var>& block_vars,
Expand Down
Loading